user discovery database store works

This commit is contained in:
otsmr 2026-04-17 00:22:38 +02:00
parent 252e7653db
commit eb22acacee
36 changed files with 18939 additions and 666 deletions

View file

@ -8,7 +8,7 @@ import 'package:flutter_rust_bridge/flutter_rust_bridge_for_generated.dart';
import 'database/contact.dart';
import 'frb_generated.dart';
// These functions are ignored because they are not marked as `pub`: `get_instance`
// These functions are ignored because they are not marked as `pub`: `get_workspace`
// These types are ignored because they are neither used by any `pub` functions nor (for structs and enums) marked `#[frb(unignore)]`: `Twonly`
Future<void> initializeTwonly({required TwonlyConfig config}) =>
@ -30,7 +30,7 @@ class OtherPromotion {
this.publicKeyVerifiedTimestamp,
});
final int promotionId;
final BigInt publicId;
final PlatformInt64 publicId;
final PlatformInt64 fromContactId;
final int threshold;
final Uint8List announcementShare;
@ -61,16 +61,19 @@ class OtherPromotion {
class TwonlyConfig {
const TwonlyConfig({
required this.databasePath,
required this.dataDirectory,
});
final String databasePath;
final String dataDirectory;
@override
int get hashCode => databasePath.hashCode;
int get hashCode => databasePath.hashCode ^ dataDirectory.hashCode;
@override
bool operator ==(Object other) =>
identical(this, other) ||
other is TwonlyConfig &&
runtimeType == other.runtimeType &&
databasePath == other.databasePath;
databasePath == other.databasePath &&
dataDirectory == other.dataDirectory;
}

View file

@ -0,0 +1,25 @@
// This file is automatically generated, so please do not edit it.
// @generated by `flutter_rust_bridge`@ 2.12.0.
// ignore_for_file: invalid_use_of_internal_member, unused_import
import 'package:flutter_rust_bridge/flutter_rust_bridge_for_generated.dart';
import '../frb_generated.dart';
// These function are ignored because they are on traits that is not defined in current crate (put an empty `#[frb]` on it to unignore): `clone`, `get_all_announced_users`, `get_announced_user_by_public_id`, `get_config`, `get_contact_version`, `get_other_promotions_by_public_id`, `get_own_promotions_after_version`, `get_share_for_contact`, `push_new_user_relation`, `push_own_promotion`, `set_contact_version`, `set_shares`, `store_other_promotion`, `update_config`
class UserDiscoveryDatabaseStore {
const UserDiscoveryDatabaseStore();
static Future<UserDiscoveryDatabaseStore> default_() => RustLib.instance.api
.crateBridgeUserDiscoveryUserDiscoveryDatabaseStoreDefault();
@override
int get hashCode => 0;
@override
bool operator ==(Object other) =>
identical(this, other) ||
other is UserDiscoveryDatabaseStore && runtimeType == other.runtimeType;
}

View file

@ -253,7 +253,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
throw Exception('unexpected arr length: expect 6 but see ${arr.length}');
return OtherPromotion(
promotionId: dco_decode_u_32(arr[0]),
publicId: dco_decode_u_64(arr[1]),
publicId: dco_decode_i_64(arr[1]),
fromContactId: dco_decode_i_64(arr[2]),
threshold: dco_decode_u_8(arr[3]),
announcementShare: dco_decode_list_prim_u_8_strict(arr[4]),
@ -265,10 +265,11 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
TwonlyConfig dco_decode_twonly_config(dynamic raw) {
// Codec=Dco (DartCObject based), see doc to use other codecs
final arr = raw as List<dynamic>;
if (arr.length != 1)
throw Exception('unexpected arr length: expect 1 but see ${arr.length}');
if (arr.length != 2)
throw Exception('unexpected arr length: expect 2 but see ${arr.length}');
return TwonlyConfig(
databasePath: dco_decode_String(arr[0]),
dataDirectory: dco_decode_String(arr[1]),
);
}
@ -278,12 +279,6 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
return raw as int;
}
@protected
BigInt dco_decode_u_64(dynamic raw) {
// Codec=Dco (DartCObject based), see doc to use other codecs
return dcoDecodeU64(raw);
}
@protected
int dco_decode_u_8(dynamic raw) {
// Codec=Dco (DartCObject based), see doc to use other codecs
@ -372,7 +367,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
OtherPromotion sse_decode_other_promotion(SseDeserializer deserializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
final var_promotionId = sse_decode_u_32(deserializer);
final var_publicId = sse_decode_u_64(deserializer);
final var_publicId = sse_decode_i_64(deserializer);
final var_fromContactId = sse_decode_i_64(deserializer);
final var_threshold = sse_decode_u_8(deserializer);
final var_announcementShare = sse_decode_list_prim_u_8_strict(deserializer);
@ -393,7 +388,11 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
TwonlyConfig sse_decode_twonly_config(SseDeserializer deserializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
final var_databasePath = sse_decode_String(deserializer);
return TwonlyConfig(databasePath: var_databasePath);
final var_dataDirectory = sse_decode_String(deserializer);
return TwonlyConfig(
databasePath: var_databasePath,
dataDirectory: var_dataDirectory,
);
}
@protected
@ -402,12 +401,6 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
return deserializer.buffer.getUint32();
}
@protected
BigInt sse_decode_u_64(SseDeserializer deserializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
return deserializer.buffer.getBigUint64();
}
@protected
int sse_decode_u_8(SseDeserializer deserializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
@ -516,7 +509,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
) {
// Codec=Sse (Serialization based), see doc to use other codecs
sse_encode_u_32(self.promotionId, serializer);
sse_encode_u_64(self.publicId, serializer);
sse_encode_i_64(self.publicId, serializer);
sse_encode_i_64(self.fromContactId, serializer);
sse_encode_u_8(self.threshold, serializer);
sse_encode_list_prim_u_8_strict(self.announcementShare, serializer);
@ -530,6 +523,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
void sse_encode_twonly_config(TwonlyConfig self, SseSerializer serializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
sse_encode_String(self.databasePath, serializer);
sse_encode_String(self.dataDirectory, serializer);
}
@protected
@ -538,12 +532,6 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
serializer.buffer.putUint32(self);
}
@protected
void sse_encode_u_64(BigInt self, SseSerializer serializer) {
// Codec=Sse (Serialization based), see doc to use other codecs
serializer.buffer.putBigUint64(self);
}
@protected
void sse_encode_u_8(int self, SseSerializer serializer) {
// Codec=Sse (Serialization based), see doc to use other codecs

View file

@ -57,9 +57,6 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
int dco_decode_u_32(dynamic raw);
@protected
BigInt dco_decode_u_64(dynamic raw);
@protected
int dco_decode_u_8(dynamic raw);
@ -104,9 +101,6 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
int sse_decode_u_32(SseDeserializer deserializer);
@protected
BigInt sse_decode_u_64(SseDeserializer deserializer);
@protected
int sse_decode_u_8(SseDeserializer deserializer);
@ -173,9 +167,6 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
void sse_encode_u_32(int self, SseSerializer serializer);
@protected
void sse_encode_u_64(BigInt self, SseSerializer serializer);
@protected
void sse_encode_u_8(int self, SseSerializer serializer);

View file

@ -58,9 +58,6 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
int dco_decode_u_32(dynamic raw);
@protected
BigInt dco_decode_u_64(dynamic raw);
@protected
int dco_decode_u_8(dynamic raw);
@ -105,9 +102,6 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
int sse_decode_u_32(SseDeserializer deserializer);
@protected
BigInt sse_decode_u_64(SseDeserializer deserializer);
@protected
int sse_decode_u_8(SseDeserializer deserializer);
@ -174,9 +168,6 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl<RustLibWire> {
@protected
void sse_encode_u_32(int self, SseSerializer serializer);
@protected
void sse_encode_u_64(BigInt self, SseSerializer serializer);
@protected
void sse_encode_u_8(int self, SseSerializer serializer);

View file

@ -40,6 +40,7 @@ void main() async {
await bridge.initializeTwonly(
config: bridge.TwonlyConfig(
databasePath: '$globalApplicationSupportDirectory/twonly.sqlite',
dataDirectory: globalApplicationSupportDirectory,
),
);

View file

@ -0,0 +1,21 @@
import 'package:drift/drift.dart';
import 'package:twonly/src/database/tables/user_discovery.table.dart';
import 'package:twonly/src/database/twonly.db.dart';
part 'user_discovery.dao.g.dart';
@DriftAccessor(
tables: [
UserDiscoveryAnnouncedUsers,
UserDiscoveryUserRelations,
UserDiscoveryOwnPromotions,
UserDiscoveryShares,
],
)
class UserDiscoveryDao extends DatabaseAccessor<TwonlyDB>
with _$UserDiscoveryDaoMixin {
// this constructor is required so that the main database can create an instance
// of this object.
// ignore: matching_super_parameters
UserDiscoveryDao(super.db);
}

View file

@ -0,0 +1,47 @@
// GENERATED CODE - DO NOT MODIFY BY HAND
part of 'user_discovery.dao.dart';
// ignore_for_file: type=lint
mixin _$UserDiscoveryDaoMixin on DatabaseAccessor<TwonlyDB> {
$UserDiscoveryAnnouncedUsersTable get userDiscoveryAnnouncedUsers =>
attachedDatabase.userDiscoveryAnnouncedUsers;
$ContactsTable get contacts => attachedDatabase.contacts;
$UserDiscoveryUserRelationsTable get userDiscoveryUserRelations =>
attachedDatabase.userDiscoveryUserRelations;
$UserDiscoveryOwnPromotionsTable get userDiscoveryOwnPromotions =>
attachedDatabase.userDiscoveryOwnPromotions;
$UserDiscoverySharesTable get userDiscoveryShares =>
attachedDatabase.userDiscoveryShares;
UserDiscoveryDaoManager get managers => UserDiscoveryDaoManager(this);
}
class UserDiscoveryDaoManager {
final _$UserDiscoveryDaoMixin _db;
UserDiscoveryDaoManager(this._db);
$$UserDiscoveryAnnouncedUsersTableTableManager
get userDiscoveryAnnouncedUsers =>
$$UserDiscoveryAnnouncedUsersTableTableManager(
_db.attachedDatabase,
_db.userDiscoveryAnnouncedUsers,
);
$$ContactsTableTableManager get contacts =>
$$ContactsTableTableManager(_db.attachedDatabase, _db.contacts);
$$UserDiscoveryUserRelationsTableTableManager
get userDiscoveryUserRelations =>
$$UserDiscoveryUserRelationsTableTableManager(
_db.attachedDatabase,
_db.userDiscoveryUserRelations,
);
$$UserDiscoveryOwnPromotionsTableTableManager
get userDiscoveryOwnPromotions =>
$$UserDiscoveryOwnPromotionsTableTableManager(
_db.attachedDatabase,
_db.userDiscoveryOwnPromotions,
);
$$UserDiscoverySharesTableTableManager get userDiscoveryShares =>
$$UserDiscoverySharesTableTableManager(
_db.attachedDatabase,
_db.userDiscoveryShares,
);
}

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,6 @@
import 'package:drift/drift.dart';
@DataClassName('Contact')
class Contacts extends Table {
IntColumn get userId => integer()();
@ -22,6 +23,37 @@ class Contacts extends Table {
DateTimeColumn get createdAt => dateTime().withDefault(currentDateAndTime)();
// contact_versions: HashMap<UserID, Vec<u8>>,
BlobColumn get userDiscoveryVersion => blob().nullable()();
@override
Set<Column> get primaryKey => {userId};
}
enum VerificationType {
qr,
link,
}
@DataClassName('KeyVerification')
class KeyVerifications extends Table {
IntColumn get contactId => integer().references(
Contacts,
#userId,
onDelete: KeyAction.cascade,
)();
TextColumn get type => textEnum<VerificationType>()();
DateTimeColumn get createdAt => dateTime().withDefault(currentDateAndTime)();
@override
Set<Column> get primaryKey => {contactId};
}
@DataClassName('VerificationToken')
class VerificationTokens extends Table {
IntColumn get tokenId => integer().autoIncrement()();
BlobColumn get token => blob()();
DateTimeColumn get createdAt => dateTime().withDefault(currentDateAndTime)();
}

View file

@ -1,12 +1,39 @@
import 'package:drift/drift.dart';
import 'package:twonly/src/database/tables/contacts.table.dart';
// contact_versions: HashMap<UserID, Vec<u8>>,
// -> New Column in Contacts
// config: Option<Vec<u8>>,
// announced_users: HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>>,
@DataClassName('UserDiscoveryAnnouncedUser')
class UserDiscoveryAnnouncedUsers extends Table {
IntColumn get announcedUserId => integer()();
BlobColumn get announcedPublicKey => blob()();
IntColumn get publicId => integer().unique()();
@override
Set<Column> get primaryKey => {announcedUserId};
}
// announced_users: HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>>,
@DataClassName('UserDiscoveryUserRelation')
class UserDiscoveryUserRelations extends Table {
IntColumn get announcedUserId => integer().references(
UserDiscoveryAnnouncedUsers,
#announcedUserId,
onDelete: KeyAction.cascade,
)();
IntColumn get fromContactId => integer().references(
Contacts,
#userId,
onDelete: KeyAction.cascade,
)();
DateTimeColumn get publicKeyVerifiedTimestamp => dateTime().nullable()();
@override
Set<Column> get primaryKey => {announcedUserId, fromContactId};
}
// own_promotions: Vec<(UserID, Vec<u8>)>,
@DataClassName('UserDiscoveryOwnPromotion')
@ -17,21 +44,26 @@ class UserDiscoveryOwnPromotions extends Table {
#userId,
onDelete: KeyAction.cascade,
)();
BlobColumn get promotion => blob()();
}
// other_promotions: Vec<OtherPromotion>,
@DataClassName('UserDiscoveryOtherPromotion')
class UserDiscoveryOtherPromotions extends Table {
IntColumn get versionId => integer().autoIncrement()();
IntColumn get contactId => integer().references(
IntColumn get fromContactId => integer().references(
Contacts,
#userId,
onDelete: KeyAction.cascade,
)();
BlobColumn get promotion => blob()();
IntColumn get promotionId => integer()();
IntColumn get publicId => integer()();
IntColumn get threshold => integer()();
BlobColumn get announcementShare => blob()();
DateTimeColumn get publicKeyVerifiedTimestamp => dateTime().nullable()();
@override
Set<Column> get primaryKey => {fromContactId, promotionId};
}
// unused_shares: Vec<Vec<u8>>,

View file

@ -9,6 +9,7 @@ import 'package:twonly/src/database/daos/mediafiles.dao.dart';
import 'package:twonly/src/database/daos/messages.dao.dart';
import 'package:twonly/src/database/daos/reactions.dao.dart';
import 'package:twonly/src/database/daos/receipts.dao.dart';
import 'package:twonly/src/database/daos/user_discovery.dao.dart';
import 'package:twonly/src/database/tables/contacts.table.dart';
import 'package:twonly/src/database/tables/groups.table.dart';
import 'package:twonly/src/database/tables/mediafiles.table.dart';
@ -19,6 +20,7 @@ import 'package:twonly/src/database/tables/signal_identity_key_store.table.dart'
import 'package:twonly/src/database/tables/signal_pre_key_store.table.dart';
import 'package:twonly/src/database/tables/signal_sender_key_store.table.dart';
import 'package:twonly/src/database/tables/signal_session_store.table.dart';
import 'package:twonly/src/database/tables/user_discovery.table.dart';
import 'package:twonly/src/database/twonly.db.steps.dart';
import 'package:twonly/src/utils/log.dart';
@ -42,6 +44,13 @@ part 'twonly.db.g.dart';
SignalSessionStores,
MessageActions,
GroupHistories,
KeyVerifications,
VerificationTokens,
UserDiscoveryAnnouncedUsers,
UserDiscoveryUserRelations,
UserDiscoveryOtherPromotions,
UserDiscoveryOwnPromotions,
UserDiscoveryShares,
],
daos: [
MessagesDao,
@ -50,6 +59,7 @@ part 'twonly.db.g.dart';
GroupsDao,
ReactionsDao,
MediaFilesDao,
UserDiscoveryDao,
],
)
class TwonlyDB extends _$TwonlyDB {
@ -62,7 +72,7 @@ class TwonlyDB extends _$TwonlyDB {
TwonlyDB.forTesting(DatabaseConnection super.connection);
@override
int get schemaVersion => 11;
int get schemaVersion => 12;
static QueryExecutor _openConnection() {
return driftDatabase(
@ -158,6 +168,19 @@ class TwonlyDB extends _$TwonlyDB {
schema.groupMembers.lastTypeIndicator,
);
},
from11To12: (m, schema) async {
await m.createTable(schema.verificationTokens);
await m.createTable(schema.keyVerifications);
await m.createTable(schema.userDiscoveryAnnouncedUsers);
await m.createTable(schema.userDiscoveryOwnPromotions);
await m.createTable(schema.userDiscoveryOtherPromotions);
await m.createTable(schema.userDiscoveryShares);
await m.createTable(schema.userDiscoveryUserRelations);
await m.addColumn(
schema.contacts,
schema.contacts.userDiscoveryVersion,
);
},
)(m, from, to);
},
);

File diff suppressed because it is too large Load diff

View file

@ -5823,6 +5823,662 @@ i1.GeneratedColumn<int> _column_210(String aliasedName) =>
type: i1.DriftSqlType.int,
$customConstraints: 'NULL',
);
final class Schema12 extends i0.VersionedSchema {
Schema12({required super.database}) : super(version: 12);
@override
late final List<i1.DatabaseSchemaEntity> entities = [
contacts,
groups,
mediaFiles,
messages,
messageHistories,
reactions,
groupMembers,
receipts,
receivedReceipts,
signalIdentityKeyStores,
signalPreKeyStores,
signalSenderKeyStores,
signalSessionStores,
messageActions,
groupHistories,
keyVerifications,
verificationTokens,
userDiscoveryAnnouncedUsers,
userDiscoveryUserRelations,
userDiscoveryOtherPromotions,
userDiscoveryOwnPromotions,
userDiscoveryShares,
];
late final Shape39 contacts = Shape39(
source: i0.VersionedTable(
entityName: 'contacts',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(user_id)'],
columns: [
_column_106,
_column_107,
_column_108,
_column_109,
_column_110,
_column_111,
_column_112,
_column_113,
_column_114,
_column_115,
_column_116,
_column_117,
_column_118,
_column_211,
],
attachedDatabase: database,
),
alias: null,
);
late final Shape23 groups = Shape23(
source: i0.VersionedTable(
entityName: 'groups',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(group_id)'],
columns: [
_column_119,
_column_120,
_column_121,
_column_122,
_column_123,
_column_124,
_column_125,
_column_126,
_column_127,
_column_128,
_column_129,
_column_130,
_column_131,
_column_132,
_column_133,
_column_134,
_column_118,
_column_135,
_column_136,
_column_137,
_column_138,
_column_139,
_column_140,
_column_141,
_column_142,
],
attachedDatabase: database,
),
alias: null,
);
late final Shape36 mediaFiles = Shape36(
source: i0.VersionedTable(
entityName: 'media_files',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(media_id)'],
columns: [
_column_143,
_column_144,
_column_145,
_column_146,
_column_147,
_column_148,
_column_149,
_column_207,
_column_150,
_column_151,
_column_152,
_column_153,
_column_154,
_column_155,
_column_156,
_column_157,
_column_118,
],
attachedDatabase: database,
),
alias: null,
);
late final Shape25 messages = Shape25(
source: i0.VersionedTable(
entityName: 'messages',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(message_id)'],
columns: [
_column_158,
_column_159,
_column_160,
_column_144,
_column_161,
_column_162,
_column_163,
_column_164,
_column_165,
_column_153,
_column_166,
_column_167,
_column_168,
_column_169,
_column_118,
_column_170,
_column_171,
_column_172,
],
attachedDatabase: database,
),
alias: null,
);
late final Shape26 messageHistories = Shape26(
source: i0.VersionedTable(
entityName: 'message_histories',
withoutRowId: false,
isStrict: false,
tableConstraints: [],
columns: [
_column_173,
_column_174,
_column_175,
_column_161,
_column_118,
],
attachedDatabase: database,
),
alias: null,
);
late final Shape27 reactions = Shape27(
source: i0.VersionedTable(
entityName: 'reactions',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(message_id, sender_id, emoji)'],
columns: [_column_174, _column_176, _column_177, _column_118],
attachedDatabase: database,
),
alias: null,
);
late final Shape38 groupMembers = Shape38(
source: i0.VersionedTable(
entityName: 'group_members',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(group_id, contact_id)'],
columns: [
_column_158,
_column_178,
_column_179,
_column_180,
_column_209,
_column_210,
_column_181,
_column_118,
],
attachedDatabase: database,
),
alias: null,
);
late final Shape37 receipts = Shape37(
source: i0.VersionedTable(
entityName: 'receipts',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(receipt_id)'],
columns: [
_column_182,
_column_183,
_column_184,
_column_185,
_column_186,
_column_208,
_column_187,
_column_188,
_column_189,
_column_190,
_column_191,
_column_118,
],
attachedDatabase: database,
),
alias: null,
);
late final Shape30 receivedReceipts = Shape30(
source: i0.VersionedTable(
entityName: 'received_receipts',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(receipt_id)'],
columns: [_column_182, _column_118],
attachedDatabase: database,
),
alias: null,
);
late final Shape31 signalIdentityKeyStores = Shape31(
source: i0.VersionedTable(
entityName: 'signal_identity_key_stores',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(device_id, name)'],
columns: [_column_192, _column_193, _column_194, _column_118],
attachedDatabase: database,
),
alias: null,
);
late final Shape32 signalPreKeyStores = Shape32(
source: i0.VersionedTable(
entityName: 'signal_pre_key_stores',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(pre_key_id)'],
columns: [_column_195, _column_196, _column_118],
attachedDatabase: database,
),
alias: null,
);
late final Shape11 signalSenderKeyStores = Shape11(
source: i0.VersionedTable(
entityName: 'signal_sender_key_stores',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(sender_key_name)'],
columns: [_column_197, _column_198],
attachedDatabase: database,
),
alias: null,
);
late final Shape33 signalSessionStores = Shape33(
source: i0.VersionedTable(
entityName: 'signal_session_stores',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(device_id, name)'],
columns: [_column_192, _column_193, _column_199, _column_118],
attachedDatabase: database,
),
alias: null,
);
late final Shape34 messageActions = Shape34(
source: i0.VersionedTable(
entityName: 'message_actions',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(message_id, contact_id, type)'],
columns: [_column_174, _column_183, _column_144, _column_200],
attachedDatabase: database,
),
alias: null,
);
late final Shape35 groupHistories = Shape35(
source: i0.VersionedTable(
entityName: 'group_histories',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(group_history_id)'],
columns: [
_column_201,
_column_158,
_column_202,
_column_203,
_column_204,
_column_205,
_column_206,
_column_144,
_column_200,
],
attachedDatabase: database,
),
alias: null,
);
late final Shape40 keyVerifications = Shape40(
source: i0.VersionedTable(
entityName: 'key_verifications',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(contact_id)'],
columns: [_column_183, _column_144, _column_118],
attachedDatabase: database,
),
alias: null,
);
late final Shape41 verificationTokens = Shape41(
source: i0.VersionedTable(
entityName: 'verification_tokens',
withoutRowId: false,
isStrict: false,
tableConstraints: [],
columns: [_column_212, _column_213, _column_118],
attachedDatabase: database,
),
alias: null,
);
late final Shape42 userDiscoveryAnnouncedUsers = Shape42(
source: i0.VersionedTable(
entityName: 'user_discovery_announced_users',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(announced_user_id)'],
columns: [_column_214, _column_215, _column_216],
attachedDatabase: database,
),
alias: null,
);
late final Shape43 userDiscoveryUserRelations = Shape43(
source: i0.VersionedTable(
entityName: 'user_discovery_user_relations',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(announced_user_id, from_contact_id)'],
columns: [_column_217, _column_218, _column_219],
attachedDatabase: database,
),
alias: null,
);
late final Shape44 userDiscoveryOtherPromotions = Shape44(
source: i0.VersionedTable(
entityName: 'user_discovery_other_promotions',
withoutRowId: false,
isStrict: false,
tableConstraints: ['PRIMARY KEY(from_contact_id, promotion_id)'],
columns: [
_column_218,
_column_220,
_column_221,
_column_222,
_column_223,
_column_219,
],
attachedDatabase: database,
),
alias: null,
);
late final Shape45 userDiscoveryOwnPromotions = Shape45(
source: i0.VersionedTable(
entityName: 'user_discovery_own_promotions',
withoutRowId: false,
isStrict: false,
tableConstraints: [],
columns: [_column_224, _column_183, _column_225],
attachedDatabase: database,
),
alias: null,
);
late final Shape46 userDiscoveryShares = Shape46(
source: i0.VersionedTable(
entityName: 'user_discovery_shares',
withoutRowId: false,
isStrict: false,
tableConstraints: [],
columns: [_column_226, _column_227, _column_175],
attachedDatabase: database,
),
alias: null,
);
}
class Shape39 extends i0.VersionedTable {
Shape39({required super.source, required super.alias}) : super.aliased();
i1.GeneratedColumn<int> get userId =>
columnsByName['user_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<String> get username =>
columnsByName['username']! as i1.GeneratedColumn<String>;
i1.GeneratedColumn<String> get displayName =>
columnsByName['display_name']! as i1.GeneratedColumn<String>;
i1.GeneratedColumn<String> get nickName =>
columnsByName['nick_name']! as i1.GeneratedColumn<String>;
i1.GeneratedColumn<i2.Uint8List> get avatarSvgCompressed =>
columnsByName['avatar_svg_compressed']!
as i1.GeneratedColumn<i2.Uint8List>;
i1.GeneratedColumn<int> get senderProfileCounter =>
columnsByName['sender_profile_counter']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get accepted =>
columnsByName['accepted']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get deletedByUser =>
columnsByName['deleted_by_user']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get requested =>
columnsByName['requested']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get blocked =>
columnsByName['blocked']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get verified =>
columnsByName['verified']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get accountDeleted =>
columnsByName['account_deleted']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get createdAt =>
columnsByName['created_at']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<i2.Uint8List> get userDiscoveryVersion =>
columnsByName['user_discovery_version']!
as i1.GeneratedColumn<i2.Uint8List>;
}
i1.GeneratedColumn<i2.Uint8List> _column_211(String aliasedName) =>
i1.GeneratedColumn<i2.Uint8List>(
'user_discovery_version',
aliasedName,
true,
type: i1.DriftSqlType.blob,
$customConstraints: 'NULL',
);
class Shape40 extends i0.VersionedTable {
Shape40({required super.source, required super.alias}) : super.aliased();
i1.GeneratedColumn<int> get contactId =>
columnsByName['contact_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<String> get type =>
columnsByName['type']! as i1.GeneratedColumn<String>;
i1.GeneratedColumn<int> get createdAt =>
columnsByName['created_at']! as i1.GeneratedColumn<int>;
}
class Shape41 extends i0.VersionedTable {
Shape41({required super.source, required super.alias}) : super.aliased();
i1.GeneratedColumn<int> get tokenId =>
columnsByName['token_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<i2.Uint8List> get token =>
columnsByName['token']! as i1.GeneratedColumn<i2.Uint8List>;
i1.GeneratedColumn<int> get createdAt =>
columnsByName['created_at']! as i1.GeneratedColumn<int>;
}
i1.GeneratedColumn<int> _column_212(String aliasedName) =>
i1.GeneratedColumn<int>(
'token_id',
aliasedName,
false,
hasAutoIncrement: true,
type: i1.DriftSqlType.int,
$customConstraints: 'NOT NULL PRIMARY KEY AUTOINCREMENT',
);
i1.GeneratedColumn<i2.Uint8List> _column_213(String aliasedName) =>
i1.GeneratedColumn<i2.Uint8List>(
'token',
aliasedName,
false,
type: i1.DriftSqlType.blob,
$customConstraints: 'NOT NULL',
);
class Shape42 extends i0.VersionedTable {
Shape42({required super.source, required super.alias}) : super.aliased();
i1.GeneratedColumn<int> get announcedUserId =>
columnsByName['announced_user_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<i2.Uint8List> get announcedPublicKey =>
columnsByName['announced_public_key']!
as i1.GeneratedColumn<i2.Uint8List>;
i1.GeneratedColumn<int> get publicId =>
columnsByName['public_id']! as i1.GeneratedColumn<int>;
}
i1.GeneratedColumn<int> _column_214(String aliasedName) =>
i1.GeneratedColumn<int>(
'announced_user_id',
aliasedName,
false,
type: i1.DriftSqlType.int,
$customConstraints: 'NOT NULL',
);
i1.GeneratedColumn<i2.Uint8List> _column_215(String aliasedName) =>
i1.GeneratedColumn<i2.Uint8List>(
'announced_public_key',
aliasedName,
false,
type: i1.DriftSqlType.blob,
$customConstraints: 'NOT NULL',
);
i1.GeneratedColumn<int> _column_216(String aliasedName) =>
i1.GeneratedColumn<int>(
'public_id',
aliasedName,
false,
type: i1.DriftSqlType.int,
$customConstraints: 'NOT NULL UNIQUE',
);
class Shape43 extends i0.VersionedTable {
Shape43({required super.source, required super.alias}) : super.aliased();
i1.GeneratedColumn<int> get announcedUserId =>
columnsByName['announced_user_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get fromContactId =>
columnsByName['from_contact_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get publicKeyVerifiedTimestamp =>
columnsByName['public_key_verified_timestamp']!
as i1.GeneratedColumn<int>;
}
i1.GeneratedColumn<int> _column_217(
String aliasedName,
) => i1.GeneratedColumn<int>(
'announced_user_id',
aliasedName,
false,
type: i1.DriftSqlType.int,
$customConstraints:
'NOT NULL REFERENCES user_discovery_announced_users(announced_user_id)ON DELETE CASCADE',
);
i1.GeneratedColumn<int> _column_218(String aliasedName) =>
i1.GeneratedColumn<int>(
'from_contact_id',
aliasedName,
false,
type: i1.DriftSqlType.int,
$customConstraints:
'NOT NULL REFERENCES contacts(user_id)ON DELETE CASCADE',
);
i1.GeneratedColumn<int> _column_219(String aliasedName) =>
i1.GeneratedColumn<int>(
'public_key_verified_timestamp',
aliasedName,
true,
type: i1.DriftSqlType.int,
$customConstraints: 'NULL',
);
class Shape44 extends i0.VersionedTable {
Shape44({required super.source, required super.alias}) : super.aliased();
i1.GeneratedColumn<int> get fromContactId =>
columnsByName['from_contact_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get promotionId =>
columnsByName['promotion_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get publicId =>
columnsByName['public_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get threshold =>
columnsByName['threshold']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<i2.Uint8List> get announcementShare =>
columnsByName['announcement_share']! as i1.GeneratedColumn<i2.Uint8List>;
i1.GeneratedColumn<int> get publicKeyVerifiedTimestamp =>
columnsByName['public_key_verified_timestamp']!
as i1.GeneratedColumn<int>;
}
i1.GeneratedColumn<int> _column_220(String aliasedName) =>
i1.GeneratedColumn<int>(
'promotion_id',
aliasedName,
false,
type: i1.DriftSqlType.int,
$customConstraints: 'NOT NULL',
);
i1.GeneratedColumn<int> _column_221(String aliasedName) =>
i1.GeneratedColumn<int>(
'public_id',
aliasedName,
false,
type: i1.DriftSqlType.int,
$customConstraints: 'NOT NULL',
);
i1.GeneratedColumn<int> _column_222(String aliasedName) =>
i1.GeneratedColumn<int>(
'threshold',
aliasedName,
false,
type: i1.DriftSqlType.int,
$customConstraints: 'NOT NULL',
);
i1.GeneratedColumn<i2.Uint8List> _column_223(String aliasedName) =>
i1.GeneratedColumn<i2.Uint8List>(
'announcement_share',
aliasedName,
false,
type: i1.DriftSqlType.blob,
$customConstraints: 'NOT NULL',
);
class Shape45 extends i0.VersionedTable {
Shape45({required super.source, required super.alias}) : super.aliased();
i1.GeneratedColumn<int> get versionId =>
columnsByName['version_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<int> get contactId =>
columnsByName['contact_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<i2.Uint8List> get promotion =>
columnsByName['promotion']! as i1.GeneratedColumn<i2.Uint8List>;
}
i1.GeneratedColumn<int> _column_224(String aliasedName) =>
i1.GeneratedColumn<int>(
'version_id',
aliasedName,
false,
hasAutoIncrement: true,
type: i1.DriftSqlType.int,
$customConstraints: 'NOT NULL PRIMARY KEY AUTOINCREMENT',
);
i1.GeneratedColumn<i2.Uint8List> _column_225(String aliasedName) =>
i1.GeneratedColumn<i2.Uint8List>(
'promotion',
aliasedName,
false,
type: i1.DriftSqlType.blob,
$customConstraints: 'NOT NULL',
);
class Shape46 extends i0.VersionedTable {
Shape46({required super.source, required super.alias}) : super.aliased();
i1.GeneratedColumn<int> get shareId =>
columnsByName['share_id']! as i1.GeneratedColumn<int>;
i1.GeneratedColumn<i2.Uint8List> get share =>
columnsByName['share']! as i1.GeneratedColumn<i2.Uint8List>;
i1.GeneratedColumn<int> get contactId =>
columnsByName['contact_id']! as i1.GeneratedColumn<int>;
}
i1.GeneratedColumn<int> _column_226(String aliasedName) =>
i1.GeneratedColumn<int>(
'share_id',
aliasedName,
false,
hasAutoIncrement: true,
type: i1.DriftSqlType.int,
$customConstraints: 'NOT NULL PRIMARY KEY AUTOINCREMENT',
);
i1.GeneratedColumn<i2.Uint8List> _column_227(String aliasedName) =>
i1.GeneratedColumn<i2.Uint8List>(
'share',
aliasedName,
false,
type: i1.DriftSqlType.blob,
$customConstraints: 'NOT NULL',
);
i0.MigrationStepWithVersion migrationSteps({
required Future<void> Function(i1.Migrator m, Schema2 schema) from1To2,
required Future<void> Function(i1.Migrator m, Schema3 schema) from2To3,
@ -5834,6 +6490,7 @@ i0.MigrationStepWithVersion migrationSteps({
required Future<void> Function(i1.Migrator m, Schema9 schema) from8To9,
required Future<void> Function(i1.Migrator m, Schema10 schema) from9To10,
required Future<void> Function(i1.Migrator m, Schema11 schema) from10To11,
required Future<void> Function(i1.Migrator m, Schema12 schema) from11To12,
}) {
return (currentVersion, database) async {
switch (currentVersion) {
@ -5887,6 +6544,11 @@ i0.MigrationStepWithVersion migrationSteps({
final migrator = i1.Migrator(database, schema);
await from10To11(migrator, schema);
return 11;
case 11:
final schema = Schema12(database: database);
final migrator = i1.Migrator(database, schema);
await from11To12(migrator, schema);
return 12;
default:
throw ArgumentError.value('Unknown migration from $currentVersion');
}
@ -5904,6 +6566,7 @@ i1.OnUpgrade stepByStep({
required Future<void> Function(i1.Migrator m, Schema9 schema) from8To9,
required Future<void> Function(i1.Migrator m, Schema10 schema) from9To10,
required Future<void> Function(i1.Migrator m, Schema11 schema) from10To11,
required Future<void> Function(i1.Migrator m, Schema12 schema) from11To12,
}) => i0.VersionedSchema.stepByStepHelper(
step: migrationSteps(
from1To2: from1To2,
@ -5916,5 +6579,6 @@ i1.OnUpgrade stepByStep({
from8To9: from8To9,
from9To10: from9To10,
from10To11: from10To11,
from11To12: from11To12,
),
);

87
rust/Cargo.lock generated
View file

@ -90,15 +90,6 @@ version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c59bdb34bc650a32731b31bd8f0829cc15d24a708ee31559e0bb34f2bc320cba"
[[package]]
name = "atomic-polyfill"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
dependencies = [
"critical-section",
]
[[package]]
name = "autocfg"
version = "1.5.0"
@ -244,15 +235,6 @@ version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f88a43d011fc4a6876cb7344703e297c71dda42494fee094d5f7c76bf13f746"
[[package]]
name = "cobs"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1"
dependencies = [
"thiserror",
]
[[package]]
name = "concurrent-queue"
version = "2.5.0"
@ -323,12 +305,6 @@ version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5"
[[package]]
name = "critical-section"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
[[package]]
name = "crossbeam-queue"
version = "0.3.12"
@ -466,18 +442,6 @@ dependencies = [
"serde",
]
[[package]]
name = "embedded-io"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced"
[[package]]
name = "embedded-io"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
[[package]]
name = "env_filter"
version = "0.1.4"
@ -771,15 +735,6 @@ version = "0.32.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7"
[[package]]
name = "hash32"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
dependencies = [
"byteorder",
]
[[package]]
name = "hashbrown"
version = "0.14.5"
@ -823,20 +778,6 @@ dependencies = [
"hashbrown 0.15.5",
]
[[package]]
name = "heapless"
version = "0.7.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
dependencies = [
"atomic-polyfill",
"hash32",
"rustc_version",
"serde",
"spin",
"stable_deref_trait",
]
[[package]]
name = "heck"
version = "0.5.0"
@ -1390,19 +1331,6 @@ version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
[[package]]
name = "postcard"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6764c3b5dd454e283a30e6dfe78e9b31096d9e32036b5d1eaac7a6119ccb9a24"
dependencies = [
"cobs",
"embedded-io 0.4.0",
"embedded-io 0.6.1",
"heapless",
"serde",
]
[[package]]
name = "potential_utf"
version = "0.1.5"
@ -1508,14 +1436,16 @@ dependencies = [
"base64",
"blahaj",
"hmac 0.13.0",
"postcard",
"pretty_env_logger",
"prost",
"prost-build",
"rand 0.10.1",
"serde",
"serde_json",
"sha2 0.11.0",
"sqlx",
"thiserror",
"tokio",
"tracing",
]
@ -1653,11 +1583,13 @@ name = "rust_lib_twonly"
version = "0.1.0"
dependencies = [
"flutter_rust_bridge",
"parking_lot",
"pretty_env_logger",
"prost-build",
"protocols",
"rand 0.10.1",
"sqlx",
"tempfile",
"thiserror",
"tokio",
"tracing",
@ -1669,15 +1601,6 @@ version = "0.1.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b50b8869d9fc858ce7266cce0194bd74df58b9d0e3f6df3a9fc8eb470d95c09d"
[[package]]
name = "rustc_version"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "1.1.4"

View file

@ -22,9 +22,11 @@ tokio = { version = "1.44", features = ["full"] }
tracing = "0.1.44"
rand = "0.10.1"
protocols = { path = "../protocols" }
parking_lot = "0.12.5"
[dev-dependencies]
pretty_env_logger = "0.5.0"
tempfile = "3.27.0"
[build-dependencies]

View file

@ -1,3 +1,4 @@
use protocols::user_discovery::error::UserDiscoveryError;
use thiserror::Error;
pub type Result<T> = core::result::Result<T, TwonlyError>;
@ -8,6 +9,12 @@ pub enum TwonlyError {
Initialization,
#[error("Could not find the given database")]
DatabaseNotFound,
#[error("sqlx error")]
#[error("{0}")]
SqliteError(#[from] sqlx::Error),
}
impl From<TwonlyError> for UserDiscoveryError {
fn from(error: TwonlyError) -> Self {
UserDiscoveryError::Store(error.to_string())
}
}

View file

@ -1,19 +1,24 @@
#![allow(unexpected_cfgs)]
pub mod error;
mod user_discovery;
mod user_discovery_utils;
use crate::bridge::user_discovery_utils::UserDiscoveryUtilsFlutter;
use crate::database::contact::Contact;
use crate::database::Database;
use crate::user_discovery_store::UserDiscoveryDatabaseStore;
use crate::utils::Shared;
use error::Result;
use error::TwonlyError;
use flutter_rust_bridge::frb;
use protocols::user_discovery::UserDiscovery;
use std::sync::Arc;
use tokio::sync::OnceCell;
use protocols::user_discovery::traits::OtherPromotion;
pub use protocols::user_discovery::traits::OtherPromotion;
#[frb(mirror(OtherPromotion))]
pub struct _OtherPromotion {
pub promotion_id: u32,
pub public_id: u64,
pub public_id: i64,
pub from_contact_id: i64,
pub threshold: u8,
pub announcement_share: Vec<u8>,
@ -22,26 +27,33 @@ pub struct _OtherPromotion {
pub struct TwonlyConfig {
pub database_path: String,
pub data_directory: String,
}
pub(crate) struct Twonly {
#[allow(dead_code)]
pub(crate) config: TwonlyConfig,
pub(crate) database: Arc<Database>,
pub(crate) user_discovery:
Shared<Option<UserDiscovery<UserDiscoveryDatabaseStore, UserDiscoveryUtilsFlutter>>>,
}
static GLOBAL_TWONLY: OnceCell<Twonly> = OnceCell::const_new();
fn get_instance() -> Result<&'static Twonly> {
pub(crate) fn get_workspace() -> Result<&'static Twonly> {
GLOBAL_TWONLY.get().ok_or(TwonlyError::Initialization)
}
pub async fn initialize_twonly(config: TwonlyConfig) -> Result<()> {
println!("initialized twonly");
tracing::debug!("Initialized twonly workspace.");
let twonly_res: Result<&'static Twonly> = GLOBAL_TWONLY
.get_or_try_init(|| async {
let database = Arc::new(Database::new(&config.database_path).await?);
Ok(Twonly { config, database })
Ok(Twonly {
config,
database,
user_discovery: Shared::default(),
})
})
.await;
@ -51,7 +63,7 @@ pub async fn initialize_twonly(config: TwonlyConfig) -> Result<()> {
}
pub async fn get_all_contacts() -> Result<Vec<Contact>> {
let twonly = get_instance()?;
let twonly = get_workspace()?;
Contact::get_all_contacts(twonly.database.as_ref()).await
}
@ -61,35 +73,90 @@ pub fn load_promotions() -> OtherPromotion {
#[cfg(test)]
pub(crate) mod tests {
use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
use tempfile::{NamedTempFile, TempDir};
use tokio::sync::OnceCell;
use crate::{database::Database, utils::Shared};
use super::error::Result;
use super::Twonly;
use std::path::PathBuf;
use std::{path::PathBuf, sync::Arc};
use tokio::sync::Mutex;
use super::{get_instance, initialize_twonly, TwonlyConfig};
use super::{get_workspace, initialize_twonly, TwonlyConfig};
pub(crate) async fn initialize_twonly_for_testing() -> Result<&'static Twonly> {
static TWONLY_TESTING: [OnceCell<Twonly>; 10] = [
OnceCell::const_new(),
OnceCell::const_new(),
OnceCell::const_new(),
OnceCell::const_new(),
OnceCell::const_new(),
OnceCell::const_new(),
OnceCell::const_new(),
OnceCell::const_new(),
OnceCell::const_new(),
OnceCell::const_new(),
];
static TWONLY_TESTING_INDEX: OnceCell<Arc<Mutex<usize>>> = OnceCell::const_new();
pub(crate) async fn initialize_twonly_for_testing(use_global: bool) -> Result<&'static Twonly> {
let default_twonly_database = PathBuf::from("tests/testing.db");
if !default_twonly_database.is_file() {
panic!("{} not found!", default_twonly_database.display())
}
let copied_twonly_database = default_twonly_database
.parent()
.unwrap()
.join("tmp_testing.db");
if copied_twonly_database.exists() {
std::fs::remove_file(&copied_twonly_database).unwrap();
let temp_file = NamedTempFile::new().unwrap().path().to_owned();
tracing::info!("Crated db copy: {}", temp_file.display());
let conn = SqlitePoolOptions::new()
.connect_with(
format!("sqlite://{}", default_twonly_database.display())
.parse::<SqliteConnectOptions>()
.unwrap(),
)
.await
.unwrap();
let path_str = temp_file.display().to_string();
sqlx::query("VACUUM INTO $1")
.bind(path_str)
.execute(&conn)
.await
.expect("Failed to backup database");
let tmp_dir = TempDir::new().unwrap().path().to_owned();
std::fs::create_dir_all(&tmp_dir).unwrap();
let config = TwonlyConfig {
database_path: temp_file.display().to_string(),
data_directory: tmp_dir.to_str().unwrap().to_string(),
};
if use_global {
initialize_twonly(config).await.unwrap();
get_workspace()
} else {
let index = TWONLY_TESTING_INDEX
.get_or_init(|| async { Arc::default() })
.await;
let mut index = index.lock().await;
let res: Result<&'static Twonly> = TWONLY_TESTING[*index]
.get_or_try_init(|| async {
let database = Arc::new(Database::new(&config.database_path).await?);
Ok(Twonly {
config,
database,
user_discovery: Shared::default(),
})
})
.await;
tracing::debug!("TWONLY_TESTING_INDEX: {index}");
*index += 1;
res
}
std::fs::copy(default_twonly_database, &copied_twonly_database).unwrap();
initialize_twonly(TwonlyConfig {
database_path: copied_twonly_database.display().to_string(),
})
.await
.unwrap();
get_instance()
}
}

View file

@ -1,72 +0,0 @@
use protocols::user_discovery::error::{Result, UserDiscoveryError};
use protocols::user_discovery::traits::{AnnouncedUser, OtherPromotion, UserDiscoveryStore};
use protocols::user_discovery::UserID;
use std::collections::HashMap;
struct UserDiscoveryDatabaseStore {}
impl UserDiscoveryStore for UserDiscoveryDatabaseStore {
fn get_config(&self) -> Result<Vec<u8>> {
todo!()
}
fn update_config(&self, update: Vec<u8>) -> Result<()> {
todo!()
}
fn set_shares(&self, shares: Vec<Vec<u8>>) -> Result<()> {
todo!()
}
fn get_share_for_contact(&self, contact_id: UserID) -> Result<Vec<u8>> {
todo!()
}
fn push_own_promotion(
&self,
contact_id: UserID,
version: u32,
promotion: Vec<u8>,
) -> Result<()> {
todo!()
}
fn get_own_promotions_after_version(&self, version: u32) -> Result<Vec<Vec<u8>>> {
todo!()
}
fn store_other_promotion(&self, promotion: OtherPromotion) -> Result<()> {
todo!()
}
fn get_other_promotions_by_public_id(&self, public_id: u64) -> Vec<OtherPromotion> {
todo!()
}
fn get_announced_user_by_public_id(&self, public_id: u64) -> Result<Option<AnnouncedUser>> {
todo!()
}
fn push_new_user_relation(
&self,
from_contact_id: UserID,
announced_user: AnnouncedUser,
public_key_verified_timestamp: Option<i64>,
) -> Result<()> {
todo!()
}
fn get_all_announced_users(
&self,
) -> Result<HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>>> {
todo!()
}
fn get_contact_version(&self, contact_id: UserID) -> Result<Option<Vec<u8>>> {
todo!()
}
fn set_contact_version(&self, contact_id: UserID, update: Vec<u8>) -> Result<()> {
todo!()
}
}

View file

@ -0,0 +1,27 @@
use protocols::user_discovery::error::Result;
use protocols::user_discovery::traits::UserDiscoveryUtils;
pub(crate) struct UserDiscoveryUtilsFlutter {}
impl UserDiscoveryUtils for UserDiscoveryUtilsFlutter {
async fn sign_data(&self, input_data: &[u8]) -> Result<Vec<u8>> {
todo!()
}
async fn verify_signature(
&self,
input_data: &[u8],
pubkey: &[u8],
signature: &[u8],
) -> Result<bool> {
todo!()
}
async fn verify_stored_pubkey(
&self,
from_contact_id: protocols::user_discovery::UserID,
pubkey: &[u8],
) -> Result<bool> {
todo!()
}
}

View file

@ -43,11 +43,11 @@ mod tests {
#[tokio::test]
async fn test_get_all_contacts() {
let twonly = initialize_twonly_for_testing().await.unwrap();
let twonly = initialize_twonly_for_testing(true).await.unwrap();
let contacts = Contact::get_all_contacts(&twonly.database).await.unwrap();
let users = vec!["alice", "bob", "charlie", "diana", "eve", "frank", "grace"];
let users = vec!["alice", "bob", "charlie", "david", "frank"];
assert_eq!(contacts.len(), users.len());

View file

@ -156,7 +156,7 @@ fn wire__crate__bridge__load_promotions_impl(
const _: fn() = || {
let OtherPromotion = None::<crate::bridge::OtherPromotion>.unwrap();
let _: u32 = OtherPromotion.promotion_id;
let _: u64 = OtherPromotion.public_id;
let _: i64 = OtherPromotion.public_id;
let _: i64 = OtherPromotion.from_contact_id;
let _: u8 = OtherPromotion.threshold;
let _: Vec<u8> = OtherPromotion.announcement_share;
@ -241,7 +241,7 @@ impl SseDecode for crate::bridge::OtherPromotion {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
let mut var_promotionId = <u32>::sse_decode(deserializer);
let mut var_publicId = <u64>::sse_decode(deserializer);
let mut var_publicId = <i64>::sse_decode(deserializer);
let mut var_fromContactId = <i64>::sse_decode(deserializer);
let mut var_threshold = <u8>::sse_decode(deserializer);
let mut var_announcementShare = <Vec<u8>>::sse_decode(deserializer);
@ -261,8 +261,10 @@ impl SseDecode for crate::bridge::TwonlyConfig {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
let mut var_databasePath = <String>::sse_decode(deserializer);
let mut var_dataDirectory = <String>::sse_decode(deserializer);
return crate::bridge::TwonlyConfig {
database_path: var_databasePath,
data_directory: var_dataDirectory,
};
}
}
@ -274,13 +276,6 @@ impl SseDecode for u32 {
}
}
impl SseDecode for u64 {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
deserializer.cursor.read_u64::<NativeEndian>().unwrap()
}
}
impl SseDecode for u8 {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_decode(deserializer: &mut flutter_rust_bridge::for_generated::SseDeserializer) -> Self {
@ -389,7 +384,11 @@ impl flutter_rust_bridge::IntoIntoDart<FrbWrapper<crate::bridge::OtherPromotion>
// Codec=Dco (DartCObject based), see doc to use other codecs
impl flutter_rust_bridge::IntoDart for crate::bridge::TwonlyConfig {
fn into_dart(self) -> flutter_rust_bridge::for_generated::DartAbi {
[self.database_path.into_into_dart().into_dart()].into_dart()
[
self.database_path.into_into_dart().into_dart(),
self.data_directory.into_into_dart().into_dart(),
]
.into_dart()
}
}
impl flutter_rust_bridge::for_generated::IntoDartExceptPrimitive for crate::bridge::TwonlyConfig {}
@ -464,7 +463,7 @@ impl SseEncode for crate::bridge::OtherPromotion {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {
<u32>::sse_encode(self.promotion_id, serializer);
<u64>::sse_encode(self.public_id, serializer);
<i64>::sse_encode(self.public_id, serializer);
<i64>::sse_encode(self.from_contact_id, serializer);
<u8>::sse_encode(self.threshold, serializer);
<Vec<u8>>::sse_encode(self.announcement_share, serializer);
@ -476,6 +475,7 @@ impl SseEncode for crate::bridge::TwonlyConfig {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {
<String>::sse_encode(self.database_path, serializer);
<String>::sse_encode(self.data_directory, serializer);
}
}
@ -486,13 +486,6 @@ impl SseEncode for u32 {
}
}
impl SseEncode for u64 {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {
serializer.cursor.write_u64::<NativeEndian>(self).unwrap();
}
}
impl SseEncode for u8 {
// Codec=Sse (Serialization based), see doc to use other codecs
fn sse_encode(self, serializer: &mut flutter_rust_bridge::for_generated::SseSerializer) {

View file

@ -1,3 +1,5 @@
pub mod bridge;
mod database;
mod frb_generated;
mod user_discovery_store;
mod utils;

View file

@ -0,0 +1,349 @@
#[allow(async_fn_in_trait)]
use protocols::user_discovery::error::{Result, UserDiscoveryError};
use protocols::user_discovery::traits::{AnnouncedUser, OtherPromotion, UserDiscoveryStore};
use protocols::user_discovery::UserID;
use sqlx::{QueryBuilder, Row, Sqlite, Transaction};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use crate::bridge::error::TwonlyError;
use crate::bridge::{get_workspace, Twonly};
#[derive(Clone)]
pub(crate) struct UserDiscoveryDatabaseStore {
ws: Arc<&'static Twonly>,
}
impl UserDiscoveryStore for UserDiscoveryDatabaseStore {
async fn new() -> Self {
#[cfg(test)]
return Self {
ws: Arc::new(
crate::bridge::tests::initialize_twonly_for_testing(false)
.await
.unwrap(),
),
};
#[allow(unreachable_code)]
Self {
ws: Arc::new(get_workspace().unwrap()),
}
}
async fn get_config(&self) -> Result<String> {
let config_path =
PathBuf::from(&self.ws.config.data_directory).join("user_discovery_config.json");
if !config_path.is_file() {
return Err(UserDiscoveryError::NotInitialized);
}
tracing::debug!("Loading Config from {}", config_path.display());
Ok(std::fs::read_to_string(&config_path)?)
}
async fn update_config(&self, update: String) -> Result<()> {
tracing::debug!("Updating configuration file.");
let config_path =
PathBuf::from(&self.ws.config.data_directory).join("user_discovery_config.json");
std::fs::write(config_path, &update)?;
Ok(())
}
async fn set_shares(&self, shares: Vec<Vec<u8>>) -> Result<()> {
let mut query_builder: QueryBuilder<Sqlite> =
QueryBuilder::new("INSERT INTO user_discovery_shares (share) ");
query_builder.push_values(shares, |mut b, share| {
b.push_bind(share);
});
query_builder
.build()
.execute(&self.ws.database.pool)
.await
.map_err(TwonlyError::from)?;
Ok(())
}
async fn get_share_for_contact(&self, contact_id: UserID) -> Result<Vec<u8>> {
let mut tx: Transaction<'_, Sqlite> = self
.ws
.database
.pool
.begin()
.await
.map_err(TwonlyError::from)?;
// 1. Check if the user already has a share assigned
let existing: Option<Vec<u8>> =
sqlx::query_scalar("SELECT share FROM user_discovery_shares WHERE contact_id = ?")
.bind(contact_id)
.fetch_optional(&mut *tx)
.await
.map_err(TwonlyError::from)?;
if let Some(share) = existing {
tx.commit().await.map_err(TwonlyError::from)?;
return Ok(share);
}
// 2. No share found. Try to assign an available one (where contact_id is NULL)
let rows_affected = sqlx::query(
"UPDATE user_discovery_shares
SET contact_id = ?
WHERE share_id = (
SELECT share_id FROM user_discovery_shares
WHERE contact_id IS NULL
LIMIT 1
)",
)
.bind(contact_id)
.execute(&mut *tx)
.await
.map_err(TwonlyError::from)?
.rows_affected();
if rows_affected == 0 {
return Err(UserDiscoveryError::NoSharesLeft);
}
// 3. Retrieve the newly assigned share
let assigned_share: Vec<u8> =
sqlx::query_scalar("SELECT share FROM user_discovery_shares WHERE contact_id = ?")
.bind(contact_id)
.fetch_one(&mut *tx)
.await
.map_err(TwonlyError::from)?;
tx.commit().await.map_err(TwonlyError::from)?;
Ok(assigned_share)
}
async fn push_own_promotion(
&self,
contact_id: UserID,
version: u32,
promotion: Vec<u8>,
) -> Result<()> {
sqlx::query(
r#"
INSERT INTO user_discovery_own_promotions (contact_id, promotion, version_id)
VALUES (?1, ?2, ?3)
"#,
)
.bind(contact_id)
.bind(promotion)
.bind(version as i64) // SQLite integers are usually i32/i64
.execute(&self.ws.database.pool)
.await
.map_err(TwonlyError::from)?;
Ok(())
}
async fn get_own_promotions_after_version(&self, version: u32) -> Result<Vec<Vec<u8>>> {
let promotions: Vec<Vec<u8>> = sqlx::query_scalar(
"SELECT promotion FROM user_discovery_own_promotions
WHERE version_id > ?
ORDER BY version_id ASC",
)
.bind(version as i64)
.fetch_all(&self.ws.database.pool)
.await
.map_err(TwonlyError::from)?;
Ok(promotions)
}
async fn store_other_promotion(&self, promotion: OtherPromotion) -> Result<()> {
sqlx::query(
r"
INSERT INTO user_discovery_other_promotions (
from_contact_id,
promotion_id,
public_id,
threshold,
announcement_share,
public_key_verified_timestamp
)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
",
)
.bind(promotion.from_contact_id)
.bind(promotion.promotion_id as i64)
.bind(promotion.public_id)
.bind(promotion.threshold as i64)
.bind(promotion.announcement_share)
.bind(promotion.public_key_verified_timestamp) // Option<i64> maps to NULLable
.execute(&self.ws.database.pool)
.await
.map_err(TwonlyError::from)?;
Ok(())
}
async fn get_other_promotions_by_public_id(
&self,
public_id: i64,
) -> Result<Vec<OtherPromotion>> {
let promotions = sqlx::query_as::<_, OtherPromotion>(
"SELECT * FROM user_discovery_other_promotions WHERE public_id = ?",
)
.bind(public_id)
.fetch_all(&self.ws.database.pool)
.await
.map_err(TwonlyError::from)?;
Ok(promotions)
}
async fn get_announced_user_by_public_id(
&self,
public_id: i64,
) -> Result<Option<AnnouncedUser>> {
let row = sqlx::query("SELECT * FROM user_discovery_announced_users WHERE public_id = ?")
.bind(public_id)
.fetch_optional(&self.ws.database.pool)
.await
.map_err(TwonlyError::from)?;
match row {
Some(r) => Ok(Some(AnnouncedUser {
user_id: r.get::<i64, _>("announced_user_id"),
public_key: r.get::<Vec<u8>, _>("announced_public_key"),
public_id: r.get::<i64, _>("public_id"),
})),
None => Ok(None),
}
}
async fn push_new_user_relation(
&self,
from_contact_id: UserID,
announced_user: AnnouncedUser,
public_key_verified_timestamp: Option<i64>,
) -> Result<()> {
let mut tx = self
.ws
.database
.pool
.begin()
.await
.map_err(TwonlyError::from)?;
sqlx::query(
r"
INSERT INTO user_discovery_announced_users (announced_user_id, announced_public_key, public_id)
VALUES (?1, ?2, ?3)
ON CONFLICT(announced_user_id) DO NOTHING
")
.bind(announced_user.user_id)
.bind(announced_user.public_key)
.bind(announced_user.public_id)
.execute(&mut *tx)
.await.map_err(TwonlyError::from)?;
if from_contact_id != announced_user.user_id {
tracing::debug!(
"INSERTING THAT {} KNOWS {}",
from_contact_id,
announced_user.user_id
);
sqlx::query(
r"INSERT INTO user_discovery_user_relations (
announced_user_id,
from_contact_id,
public_key_verified_timestamp
)
VALUES (?1, ?2, ?3)
ON CONFLICT(announced_user_id, from_contact_id) DO UPDATE SET
public_key_verified_timestamp = excluded.public_key_verified_timestamp
",
)
.bind(announced_user.user_id)
.bind(from_contact_id)
.bind(public_key_verified_timestamp)
.execute(&mut *tx)
.await
.map_err(TwonlyError::from)?;
}
tx.commit().await.map_err(TwonlyError::from)?;
Ok(())
}
async fn get_all_announced_users(
&self,
) -> Result<HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>>> {
let rows = sqlx::query(
r#"
SELECT
u.announced_user_id,
u.announced_public_key,
u.public_id,
r.from_contact_id,
r.public_key_verified_timestamp
FROM user_discovery_announced_users u
LEFT JOIN user_discovery_user_relations r
ON u.announced_user_id = r.announced_user_id
ORDER BY u.announced_user_id
"#,
)
.fetch_all(&self.ws.database.pool)
.await
.map_err(TwonlyError::from)?;
let mut map: HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>> = HashMap::new();
for row in rows {
let announced_user = AnnouncedUser {
user_id: row.get::<i64, _>("announced_user_id"),
public_key: row.get::<Vec<u8>, _>("announced_public_key"),
public_id: row.get::<i64, _>("public_id"),
};
let relations_list = map.entry(announced_user).or_insert_with(Vec::new);
// SQLX returns NULL for columns in a LEFT JOIN where no match is found.
if let Ok(Some(contact_id)) = row.try_get::<Option<i64>, _>("from_contact_id") {
let timestamp = row.get::<Option<i64>, _>("public_key_verified_timestamp");
relations_list.push((contact_id, timestamp));
}
}
Ok(map)
}
async fn get_contact_version(&self, contact_id: UserID) -> Result<Option<Vec<u8>>> {
let version: Option<Vec<u8>> =
sqlx::query_scalar("SELECT user_discovery_version FROM contacts WHERE user_id = ?")
.bind(contact_id)
.fetch_optional(&self.ws.database.pool)
.await
.map_err(TwonlyError::from)?;
Ok(version)
}
async fn set_contact_version(&self, contact_id: UserID, update: Vec<u8>) -> Result<()> {
sqlx::query("UPDATE contacts SET user_discovery_version = ? WHERE user_id = ?")
.bind(update)
.bind(contact_id)
.execute(&self.ws.database.pool)
.await
.map_err(TwonlyError::from)?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::user_discovery_store::UserDiscoveryDatabaseStore;
use protocols::user_discovery::tests::test_initialize_user_discovery;
#[tokio::test]
async fn test_initialize_user_discovery_database_store() {
let _ = pretty_env_logger::try_init();
test_initialize_user_discovery::<UserDiscoveryDatabaseStore>().await;
}
}

22
rust/core/src/utils.rs Normal file
View file

@ -0,0 +1,22 @@
use parking_lot::{RwLock, RwLockReadGuard};
use std::sync::Arc;
#[derive(Default, Clone)]
pub(crate) struct Shared<T>(Arc<RwLock<T>>);
impl<T> Shared<T>
where
T: Clone,
{
pub(crate) fn new(value: T) -> Self {
Self(Arc::new(RwLock::new(value)))
}
pub(crate) fn get(&self) -> RwLockReadGuard<'_, T> {
self.0.read()
}
pub(crate) fn cloned(&self) -> T {
self.0.read().clone()
}
pub(crate) fn set(&self, value: T) {
*self.0.write() = value;
}
}

Binary file not shown.

View file

@ -13,10 +13,14 @@ serde = "1.0.228"
prost = "0.14.1"
rand = "0.10.1"
blahaj = "0.6.0"
postcard = { version = "1.1.3", features = ["alloc"] }
serde_json = "1.0"
base64 = "0.22.1"
hmac = "0.13.0"
sha2 = "0.11.0"
tokio = { version = "1.44", features = ["full"] }
sqlx = { version = "0.9.0-alpha.1", default-features = false, features = [
"derive",
] }
[dev-dependencies]
pretty_env_logger = "0.5.0"

View file

@ -1,10 +1,11 @@
pub mod error;
pub mod stores;
pub mod tests;
pub mod traits;
use std::collections::HashMap;
use std::u8;
use blahaj::{Share, Sharks};
use postcard::{from_bytes, to_allocvec};
use prost::Message;
use serde::{Deserialize, Serialize};
use crate::user_discovery::error::{Result, UserDiscoveryError};
@ -14,10 +15,6 @@ use crate::user_discovery::user_discovery_message::user_discovery_promotion::Ann
use crate::user_discovery::user_discovery_message::user_discovery_promotion::announcement_share_decrypted::SignedData;
pub use traits::UserDiscoveryStore;
#[cfg(test)]
static TRANSMITTED_NETWORK_BYTES: std::sync::OnceLock<std::sync::Mutex<usize>> =
std::sync::OnceLock::new();
/// Type of the user id, this must be consistent with the user id defined in
/// the types.proto
pub type UserID = i64;
@ -29,13 +26,13 @@ struct UserDiscoveryConfig {
/// The number of required shares to get the secret
threshold: u8,
/// Currently limited to <= 255 as GF 256 is used
total_number_of_shares: usize,
total_number_of_shares: u8,
/// Version of announcements
announcement_version: u32,
/// Version of promotions
promotion_version: u32,
/// This is a random public_id associated with a single announcement.
public_id: u64,
public_id: i64,
/// Verification shares
verification_shares: Vec<Vec<u8>>,
// The users' id:
@ -77,15 +74,15 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
/// * `Ok(())` - If the user discovery was initialized or updated successfully
/// * `Err(UserDiscoveryError)` - If the user discovery was not initialized or updated successfully
///
pub fn initialize_or_update(
pub async fn initialize_or_update(
&self,
threshold: u8,
user_id: UserID,
public_key: Vec<u8>,
) -> Result<()> {
let mut config = match self.store.get_config() {
let mut config = match self.store.get_config().await {
Ok(config) => {
let mut config: UserDiscoveryConfig = from_bytes(&config)?;
let mut config: UserDiscoveryConfig = serde_json::from_str(&config)?;
config.threshold = threshold;
config
}
@ -104,15 +101,19 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
public_key,
};
let signature = self.utils.sign_data(&signed_data.encode_to_vec())?;
let signature = self.utils.sign_data(&signed_data.encode_to_vec()).await?;
let verification_shares = self.setup_announcements(&config, signed_data, signature)?;
let verification_shares = self
.setup_announcements(&config, signed_data, signature)
.await?;
config.public_id = public_id;
config.announcement_version += 1;
config.verification_shares = verification_shares;
self.store.update_config(to_allocvec(&config)?)?;
self.store
.update_config(serde_json::to_string_pretty(&config)?)
.await?;
Ok(())
}
@ -131,8 +132,8 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
/// * `Ok(Vec<u8>)` - The current version of the user discovery
/// * `Err(UserDiscoveryError)` - If there where errors in the store.
///
pub fn get_current_version(&self) -> Result<Vec<u8>> {
let config = self.get_config()?;
pub async fn get_current_version(&self) -> Result<Vec<u8>> {
let config = self.get_config().await?;
Ok(UserDiscoveryVersion {
announcement: config.announcement_version,
promotion: config.promotion_version,
@ -148,10 +149,10 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
/// * `Ok(HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>>)` - All connections the user has discovered
/// * `Err(UserDiscoveryError)` - If there where erros in the store.
///
pub fn get_all_announced_users(
pub async fn get_all_announced_users(
&self,
) -> Result<HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>>> {
self.store.get_all_announced_users()
self.store.get_all_announced_users().await
}
///
@ -167,7 +168,7 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
/// * `Ok(Vec<Vec<u8>>)` - The new user discovery messages
/// * `Err(UserDiscoveryError)` - If there where errors in the store or if the received version is invalid.
///
pub fn get_new_messages(
pub async fn get_new_messages(
&self,
contact_id: UserID,
received_version: &[u8],
@ -175,7 +176,7 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
let mut messages = vec![];
let received_version = UserDiscoveryVersion::decode(received_version)?;
let config = self.get_config()?;
let config = self.get_config().await?;
let version = Some(UserDiscoveryVersion {
announcement: config.announcement_version,
promotion: config.promotion_version,
@ -184,7 +185,7 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
if received_version.announcement < config.announcement_version {
tracing::info!("New announcement message available for {}", contact_id);
let announcement_share = self.store.get_share_for_contact(contact_id)?;
let announcement_share = self.store.get_share_for_contact(contact_id).await?;
let user_discovery_announcement = Some(UserDiscoveryAnnouncement {
public_id: config.public_id,
@ -206,16 +207,10 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
tracing::info!("New promotion message available for user {}", contact_id);
let promoting_messages = self
.store
.get_own_promotions_after_version(received_version.promotion)?;
.get_own_promotions_after_version(received_version.promotion)
.await?;
messages.extend_from_slice(&promoting_messages);
}
#[cfg(test)]
{
let mut count = TRANSMITTED_NETWORK_BYTES.get().unwrap().lock().unwrap();
for message in &messages {
*count += message.len();
}
}
Ok(messages)
}
@ -232,27 +227,37 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
/// * `Ok(bool)` - True if the user has new announcements
/// * `Err(UserDiscoveryError)` - If there where errors in the store or if the received version is invalid.
///
pub fn should_request_new_messages(&self, contact_id: UserID, version: &[u8]) -> Result<bool> {
pub async fn should_request_new_messages(
&self,
contact_id: UserID,
version: &[u8],
) -> Result<bool> {
let received_version = UserDiscoveryVersion::decode(version)?;
let stored_version = match self.store.get_contact_version(contact_id)? {
let stored_version = match self.store.get_contact_version(contact_id).await? {
Some(buf) => UserDiscoveryVersion::decode(buf.as_slice())?,
None => UserDiscoveryVersion {
announcement: 0,
promotion: 0,
},
};
tracing::debug!(
received.announcement = %received_version.announcement,
received.promotion = %received_version.promotion,
stored.announcement = %stored_version.announcement,
stored.promotion = %stored_version.promotion,
"Comparing version numbers"
);
Ok(received_version.announcement > stored_version.announcement
|| received_version.promotion > stored_version.promotion)
}
#[cfg(test)]
pub(crate) fn get_contact_version(&self, contact_id: UserID) -> Result<Option<Vec<u8>>> {
self.store.get_contact_version(contact_id)
pub(crate) async fn get_contact_version(&self, contact_id: UserID) -> Result<Option<Vec<u8>>> {
self.store.get_contact_version(contact_id).await
}
/// Returns the latest version for this discovery.
/// Before calling this function the application must sure that contact_id is qualified to be announced.
pub fn handle_user_discovery_messages(
pub async fn handle_user_discovery_messages(
&self,
contact_id: UserID,
messages: Vec<Vec<u8>>,
@ -264,22 +269,25 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
};
if let Some(uda) = message.user_discovery_announcement {
self.handle_user_discovery_announcement(contact_id, uda)?;
self.handle_user_discovery_announcement(contact_id, uda)
.await?;
} else if let Some(udp) = message.user_discovery_promotion {
self.handle_user_discovery_promotion(contact_id, udp)?;
self.handle_user_discovery_promotion(contact_id, udp)
.await?;
} else {
tracing::warn!("Got unknown user discovery messaging. Ignoring it.");
continue;
}
self.store
.set_contact_version(contact_id, version.encode_to_vec())?;
.set_contact_version(contact_id, version.encode_to_vec())
.await?;
}
Ok(())
}
fn setup_announcements(
async fn setup_announcements(
&self,
config: &UserDiscoveryConfig,
signed_data: SignedData,
@ -301,7 +309,7 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
let dealer = sharks.dealer(&encrypted_announcement);
let mut shares: Vec<Vec<u8>> = dealer
.take(config.total_number_of_shares)
.take(config.total_number_of_shares as usize)
.map(|x| Vec::from(&x))
.collect();
@ -325,16 +333,16 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
let split_index = shares.len() - (config.threshold - 1) as usize;
verification_shares.extend(shares.drain(split_index..));
self.store.set_shares(shares)?;
self.store.set_shares(shares).await?;
Ok(verification_shares)
}
fn get_config(&self) -> Result<UserDiscoveryConfig> {
Ok(from_bytes(&self.store.get_config()?)?)
async fn get_config(&self) -> Result<UserDiscoveryConfig> {
Ok(serde_json::from_str(&self.store.get_config().await?)?)
}
fn handle_user_discovery_announcement(
async fn handle_user_discovery_announcement(
&self,
contact_id: UserID,
uda: UserDiscoveryAnnouncement,
@ -378,27 +386,34 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
if !self
.utils
.verify_stored_pubkey(contact_id, &signed_data.public_key)?
.verify_stored_pubkey(contact_id, &signed_data.public_key)
.await?
{
return Err(UserDiscoveryError::MaliciousAnnouncementData(format!(
"public key does not match with stored one",
)));
}
if !self.utils.verify_signature(
&signed_data.encode_to_vec(),
&signed_data.public_key,
&asd.signature,
)? {
if !self
.utils
.verify_signature(
&signed_data.encode_to_vec(),
&signed_data.public_key,
&asd.signature,
)
.await?
{
return Err(UserDiscoveryError::MaliciousAnnouncementData(format!(
"signature invalid",
)));
}
tracing::debug!("Increased promotion version id.");
let mut config = self.get_config()?;
let mut config = self.get_config().await?;
config.promotion_version += 1;
self.store.update_config(to_allocvec(&config)?)?;
self.store
.update_config(serde_json::to_string_pretty(&config)?)
.await?;
let message = UserDiscoveryMessage {
version: Some(UserDiscoveryVersion {
@ -415,11 +430,13 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
..Default::default()
};
self.store.push_own_promotion(
contact_id,
config.promotion_version,
message.encode_to_vec(),
)?;
self.store
.push_own_promotion(
contact_id,
config.promotion_version,
message.encode_to_vec(),
)
.await?;
let announced_user = AnnouncedUser {
user_id: signed_data.user_id,
@ -428,16 +445,18 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
};
tracing::debug!(
"NEW PROMOTION: {} knows {}",
"NEW PROMOTION 3: {} knows {}",
contact_id,
announced_user.user_id
);
// User is known, so add him to thr users relations
self.store.push_new_user_relation(
contact_id,
announced_user,
None, // This flag mus be handled by the applications as this comes from an announcement.
)?;
self.store
.push_new_user_relation(
contact_id,
announced_user,
None, // This flag mus be handled by the applications as this comes from an announcement.
)
.await?;
Ok(())
}
@ -445,38 +464,45 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
}
}
fn handle_user_discovery_promotion(
async fn handle_user_discovery_promotion(
&self,
from_contact_id: UserID,
udp: UserDiscoveryPromotion,
) -> Result<()> {
tracing::debug!("Received a new UDP with public_id = {}.", &udp.public_id);
self.store.store_other_promotion(OtherPromotion {
from_contact_id,
promotion_id: udp.promotion_id,
threshold: udp.threshold as u8,
public_id: udp.public_id,
announcement_share: udp.announcement_share,
public_key_verified_timestamp: udp.public_key_verified_timestamp,
})?;
self.store
.store_other_promotion(OtherPromotion {
from_contact_id,
promotion_id: udp.promotion_id,
threshold: udp.threshold as u8,
public_id: udp.public_id,
announcement_share: udp.announcement_share,
public_key_verified_timestamp: udp.public_key_verified_timestamp,
})
.await?;
if let Some(contact) = self.store.get_announced_user_by_public_id(udp.public_id)? {
if let Some(contact) = self
.store
.get_announced_user_by_public_id(udp.public_id)
.await?
{
tracing::debug!(
"NEW PROMOTION 2: {} knows {}",
from_contact_id,
contact.user_id
);
// The user is already known, just propagate the relation ship
self.store.push_new_user_relation(
from_contact_id,
contact,
udp.public_key_verified_timestamp,
)?;
self.store
.push_new_user_relation(from_contact_id, contact, udp.public_key_verified_timestamp)
.await?;
return Ok(());
}
let promotions = self.store.get_other_promotions_by_public_id(udp.public_id);
let promotions = self
.store
.get_other_promotions_by_public_id(udp.public_id)
.await?;
if promotions.len() < udp.threshold as usize {
tracing::debug!(
@ -507,11 +533,15 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
return Ok(());
}
if !self.utils.verify_signature(
&signed_data.encode_to_vec(),
&signed_data.public_key,
&asd.signature,
)? {
if !self
.utils
.verify_signature(
&signed_data.encode_to_vec(),
&signed_data.public_key,
&asd.signature,
)
.await?
{
return Err(UserDiscoveryError::MaliciousAnnouncementData(format!(
"signature is invalid",
)));
@ -525,7 +555,7 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
public_id: udp.public_id,
};
let user_id = self.get_config()?.user_id;
let user_id = self.get_config().await?.user_id;
for promotion in promotions {
// Do not store the announcement of the users itself.
// Or in case the promotion promotes myself
@ -535,15 +565,17 @@ impl<Store: UserDiscoveryStore, Utils: UserDiscoveryUtils> UserDiscovery<Store,
continue;
}
tracing::debug!(
"NEW PROMOTION: {} knows {}",
"NEW PROMOTION: {:x} knows {:x}",
promotion.from_contact_id,
announced_user.user_id
);
self.store.push_new_user_relation(
promotion.from_contact_id,
announced_user.clone(),
promotion.public_key_verified_timestamp,
)?;
self.store
.push_new_user_relation(
promotion.from_contact_id,
announced_user.clone(),
promotion.public_key_verified_timestamp,
)
.await?;
}
}
Ok(())
@ -557,7 +589,7 @@ impl Default for UserDiscoveryConfig {
fn default() -> Self {
Self {
threshold: 2,
total_number_of_shares: 255,
total_number_of_shares: u8::MAX,
announcement_version: 0,
promotion_version: 0,
verification_shares: vec![],
@ -566,271 +598,3 @@ impl Default for UserDiscoveryConfig {
}
}
}
#[cfg(test)]
mod tests {
use std::collections::{HashMap, HashSet};
use std::vec;
use crate::user_discovery::stores::InMemoryStore;
use crate::user_discovery::traits::tests::TestingUtils;
use crate::user_discovery::{
UserDiscovery, UserDiscoveryVersion, UserID, TRANSMITTED_NETWORK_BYTES,
};
use prost::Message;
fn get_version_bytes(announcement: u32, promotion: u32) -> Vec<u8> {
UserDiscoveryVersion {
announcement,
promotion,
}
.encode_to_vec()
}
fn get_ud(user_id: usize) -> UserDiscovery<InMemoryStore, TestingUtils> {
let store = InMemoryStore::default();
let ud = UserDiscovery::new(store.to_owned(), TestingUtils::default()).unwrap();
ud.initialize_or_update(2, user_id as UserID, vec![user_id as u8; 32])
.unwrap();
let version = ud.get_current_version().unwrap();
assert_eq!(version, get_version_bytes(1, 0));
ud
}
fn assert_new_messages(
from: (usize, &UserDiscovery<InMemoryStore, TestingUtils>),
to: (usize, &UserDiscovery<InMemoryStore, TestingUtils>),
has_new_messages: bool,
) {
// From sends a message with his current version to To
let to_received_version = &from.1.get_current_version().unwrap();
assert_eq!(
to.1.should_request_new_messages(from.0 as UserID, to_received_version)
.unwrap(),
has_new_messages
);
}
fn request_and_handle_messages(
from: (usize, &UserDiscovery<InMemoryStore, TestingUtils>),
to: (usize, &UserDiscovery<InMemoryStore, TestingUtils>),
messages_count: usize,
) {
// From sends a message with his current version to To
let to_received_version = &from.1.get_current_version().unwrap();
assert_eq!(
to.1.should_request_new_messages(from.0 as UserID, to_received_version)
.unwrap(),
true
);
// As To has a older version stored he sends a request to From: Give me all messages since version.
let from_request_version_from_to =
to.1.get_contact_version(from.0 as UserID)
.unwrap()
.unwrap_or(get_version_bytes(0, 0));
let new_messages = from
.1
.get_new_messages(to.0 as UserID, &from_request_version_from_to)
.unwrap();
assert!(new_messages.len() <= messages_count);
to.1.handle_user_discovery_messages(from.0 as UserID, new_messages)
.unwrap();
assert_eq!(
to.1.should_request_new_messages(
from.0 as UserID,
&from.1.get_current_version().unwrap()
)
.unwrap(),
false
);
}
const ALICE: usize = 0;
const BOB: usize = 1;
const CHARLIE: usize = 2;
const DAVID: usize = 3;
const FRANK: usize = 4;
const TEST_USER_COUNT: usize = 5;
struct TestUsers {
names: [&'static str; TEST_USER_COUNT],
friends: [Vec<usize>; TEST_USER_COUNT],
uds: Vec<UserDiscovery<InMemoryStore, TestingUtils>>,
}
impl TestUsers {
fn get() -> Self {
let names = ["ALICE", "BOB", "CHARLIE", "DAVID", "FRANK"];
let mut uds = vec![];
for index in 0..names.len() {
uds.push(get_ud(index));
}
let friends = [
vec![BOB, CHARLIE],
vec![ALICE, CHARLIE, DAVID],
vec![ALICE, BOB, DAVID, FRANK],
vec![BOB, CHARLIE],
vec![CHARLIE],
];
Self {
names,
uds,
friends,
}
}
}
#[test]
fn test_initialize_user_discovery() {
let _ = pretty_env_logger::try_init();
let _ = TRANSMITTED_NETWORK_BYTES.get_or_init(|| std::sync::Mutex::new(0));
let users = TestUsers::get();
fn to_all_friends(from: usize, message_count: usize, users: &TestUsers) {
for friend in &users.friends[from] {
tracing::debug!("From {} to {}", users.names[from], users.names[*friend]);
if message_count == 0 {
assert_new_messages(
(from, &users.uds[from]),
(*friend, &users.uds[*friend]),
false,
);
} else {
request_and_handle_messages(
(from, &users.uds[from]),
(*friend, &users.uds[*friend]),
message_count,
);
}
}
}
let message_flows = [
// ALICE: own announcement sending to BOB and CHARLIE
(ALICE, 1),
// BOB: own announcement + promotion for ALICE
(BOB, 2),
// BOBs version should not have any new messages for his friends
(BOB, 0),
// ALICE: promotion for BOB
(ALICE, 1),
// CHARLIE: own announcement + promotion for ALICE, BOB
(CHARLIE, 3),
// DAVID: own announcement + promotion for BOB, CHARLIE
(DAVID, 3),
// BOB: promotion for CHARLIE, DAVID
(BOB, 2),
// CHARLIE: promotion for DAVID
(CHARLIE, 1),
// FRANK: own announcement + promotion for CHARLIE
(FRANK, 2),
// CHARLIE: promotion for FRANK
(CHARLIE, 1),
// ALICE: promotion for CHARLIE
(ALICE, 1),
];
for (i, (from, count)) in message_flows.into_iter().enumerate() {
tracing::debug!("MESSAGE FLOW: {i}");
to_all_friends(from, count, &users);
}
tracing::debug!("Now all users should have the newest version.");
for from in 0..TEST_USER_COUNT {
for to in &users.friends[from] {
tracing::debug!(
"Does {} has open messages for {}?",
&users.names[from],
&users.names[*to]
);
assert_new_messages((from, &users.uds[from]), (*to, &users.uds[*to]), false);
}
}
tracing::debug!("Test if all exchanges where successful.");
let announced_users_expected = [
// ALICE should now know that BOB and CHARLIE, BOB and DAVID and CHARLIE and DAVID are friends.
// Alice should also have one protected share from Frank.
(
ALICE,
vec![
(BOB, vec![CHARLIE]), // ALICE knows Bob and that CHARLIE is connected with BOB
(CHARLIE, vec![BOB]), // ALICE knows CHARLIE and that BOB is connected with CHARLIE
(DAVID, vec![BOB, CHARLIE]), // ALICE knows DAVID and that BOB and CHARLIE are connected with DAVID
],
),
(
BOB,
vec![
(ALICE, vec![CHARLIE]),
(CHARLIE, vec![ALICE, DAVID]),
(DAVID, vec![CHARLIE]),
],
),
(
CHARLIE,
vec![
(ALICE, vec![BOB]),
(BOB, vec![ALICE, DAVID]),
(DAVID, vec![BOB]),
(FRANK, vec![]),
],
),
(
DAVID,
vec![
(ALICE, vec![BOB, CHARLIE]),
(BOB, vec![CHARLIE]),
(CHARLIE, vec![BOB]),
],
),
(FRANK, vec![(CHARLIE, vec![])]),
];
for (user, announcements) in announced_users_expected {
let announced_users2 = users.uds[user].get_all_announced_users().unwrap();
let mut announced_users = HashMap::new();
for a in announced_users2 {
announced_users.insert(a.0.user_id, a.1.iter().map(|x| x.0).collect::<Vec<_>>());
}
tracing::debug!("{} knows now: {}", users.names[user], announced_users.len());
assert_eq!(announced_users.len(), announcements.len());
for (contact_id, announced_users_expected) in announcements {
let announced_users = announced_users.get(&(contact_id as i64)).unwrap();
tracing::debug!(
"{} knows now that {} has the following friends: {}",
users.names[user],
users.names[contact_id],
announced_users
.iter()
.map(|x| users.names[*x as usize])
.collect::<Vec<_>>()
.join(", ")
);
let announced_users: HashSet<i64> = announced_users.iter().cloned().collect();
let announced_users_expected: HashSet<i64> = announced_users_expected
.iter()
.cloned()
.map(|x| x as i64)
.collect();
assert_eq!(announced_users, announced_users_expected);
}
}
let count = TRANSMITTED_NETWORK_BYTES.get().unwrap().lock().unwrap();
tracing::info!("Transmitted a total of {} bytes.", *count);
}
}

View file

@ -18,7 +18,10 @@ pub enum UserDiscoveryError {
NotInitialized,
#[error("`{0}`")]
PostcardError(#[from] postcard::Error),
JsonError(#[from] serde_json::Error),
#[error("`{0}`")]
IoError(#[from] std::io::Error),
#[error("error while calculating shamirs secret shares: `{0}`")]
ShamirsSecret(String),

View file

@ -6,7 +6,7 @@ use std::sync::{Arc, Mutex};
#[derive(Default)]
pub(crate) struct Storage {
config: Option<Vec<u8>>,
config: Option<String>,
unused_shares: Vec<Vec<u8>>,
used_shares: HashMap<UserID, Vec<u8>>,
contact_versions: HashMap<UserID, Vec<u8>>,
@ -27,24 +27,27 @@ impl InMemoryStore {
}
impl UserDiscoveryStore for InMemoryStore {
fn get_config(&self) -> Result<Vec<u8>> {
async fn new() -> Self {
Self::default()
}
async fn get_config(&self) -> Result<String> {
if let Some(storage) = self.storage().config.clone() {
return Ok(storage);
}
Err(UserDiscoveryError::NotInitialized)
}
fn update_config(&self, update: Vec<u8>) -> Result<()> {
async fn update_config(&self, update: String) -> Result<()> {
self.storage().config = Some(update);
Ok(())
}
fn set_shares(&self, shares: Vec<Vec<u8>>) -> Result<()> {
async fn set_shares(&self, shares: Vec<Vec<u8>>) -> Result<()> {
self.storage().unused_shares = shares;
Ok(())
}
fn get_share_for_contact(&self, contact_id: UserID) -> Result<Vec<u8>> {
async fn get_share_for_contact(&self, contact_id: UserID) -> Result<Vec<u8>> {
let mut storage = self.storage();
if let Some(share) = storage.used_shares.get(&contact_id) {
return Ok(share.to_vec());
@ -56,16 +59,16 @@ impl UserDiscoveryStore for InMemoryStore {
Err(UserDiscoveryError::NoSharesLeft)
}
fn get_contact_version(&self, contact_id: UserID) -> Result<Option<Vec<u8>>> {
async fn get_contact_version(&self, contact_id: UserID) -> Result<Option<Vec<u8>>> {
Ok(self.storage().contact_versions.get(&contact_id).cloned())
}
fn set_contact_version(&self, contact_id: UserID, update: Vec<u8>) -> Result<()> {
async fn set_contact_version(&self, contact_id: UserID, update: Vec<u8>) -> Result<()> {
self.storage().contact_versions.insert(contact_id, update);
Ok(())
}
fn push_own_promotion(
async fn push_own_promotion(
&self,
contact_id: UserID,
version: u32,
@ -80,7 +83,7 @@ impl UserDiscoveryStore for InMemoryStore {
Ok(())
}
fn get_own_promotions_after_version(&self, version: u32) -> Result<Vec<Vec<u8>>> {
async fn get_own_promotions_after_version(&self, version: u32) -> Result<Vec<Vec<u8>>> {
let storage = self.storage();
let elements = storage.own_promotions[(version as usize)..]
.into_iter()
@ -89,21 +92,28 @@ impl UserDiscoveryStore for InMemoryStore {
Ok(elements)
}
fn store_other_promotion(&self, promotion: OtherPromotion) -> Result<()> {
async fn store_other_promotion(&self, promotion: OtherPromotion) -> Result<()> {
self.storage().other_promotions.push(promotion);
Ok(())
}
fn get_other_promotions_by_public_id(&self, public_id: u64) -> Vec<OtherPromotion> {
self.storage()
async fn get_other_promotions_by_public_id(
&self,
public_id: i64,
) -> Result<Vec<OtherPromotion>> {
Ok(self
.storage()
.other_promotions
.iter()
.filter(|other| other.public_id == public_id)
.map(OtherPromotion::to_owned)
.collect()
.collect())
}
fn get_announced_user_by_public_id(&self, public_id: u64) -> Result<Option<AnnouncedUser>> {
async fn get_announced_user_by_public_id(
&self,
public_id: i64,
) -> Result<Option<AnnouncedUser>> {
Ok(self
.storage()
.announced_users
@ -112,13 +122,13 @@ impl UserDiscoveryStore for InMemoryStore {
.map(|u| u.0.to_owned()))
}
fn get_all_announced_users(
async fn get_all_announced_users(
&self,
) -> Result<HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>>> {
Ok(self.storage().announced_users.clone())
}
fn push_new_user_relation(
async fn push_new_user_relation(
&self,
from_contact_id: UserID,
announced_user: AnnouncedUser,

View file

@ -0,0 +1,272 @@
use crate::user_discovery::traits::tests::TestingUtils;
use crate::user_discovery::{UserDiscovery, UserDiscoveryStore, UserDiscoveryVersion, UserID};
use prost::Message;
use std::collections::{HashMap, HashSet};
use std::vec;
fn get_version_bytes(announcement: u32, promotion: u32) -> Vec<u8> {
UserDiscoveryVersion {
announcement,
promotion,
}
.encode_to_vec()
}
async fn get_ud<S: UserDiscoveryStore + Clone>(user_id: usize) -> UserDiscovery<S, TestingUtils> {
let store = S::new().await;
let ud = UserDiscovery::new(store.to_owned(), TestingUtils::default()).unwrap();
ud.initialize_or_update(2, user_id as UserID, vec![user_id as u8; 32])
.await
.unwrap();
let version = ud.get_current_version().await.unwrap();
assert_eq!(version, get_version_bytes(1, 0));
ud
}
async fn assert_new_messages<S: UserDiscoveryStore>(
from: (usize, &UserDiscovery<S, TestingUtils>),
to: (usize, &UserDiscovery<S, TestingUtils>),
has_new_messages: bool,
) {
// From sends a message with his current version to To
let to_received_version = &from.1.get_current_version().await.unwrap();
assert_eq!(
to.1.should_request_new_messages(from.0 as UserID, to_received_version)
.await
.unwrap(),
has_new_messages
);
}
async fn request_and_handle_messages<S: UserDiscoveryStore>(
from: (usize, &UserDiscovery<S, TestingUtils>),
to: (usize, &UserDiscovery<S, TestingUtils>),
messages_count: usize,
) {
// From sends a message with his current version to To
let to_received_version = &from.1.get_current_version().await.unwrap();
assert_eq!(
to.1.should_request_new_messages(from.0 as UserID, to_received_version)
.await
.unwrap(),
true
);
// As To has a older version stored he sends a request to From: Give me all messages since version.
let from_request_version_from_to =
to.1.get_contact_version(from.0 as UserID)
.await
.unwrap()
.unwrap_or(get_version_bytes(0, 0));
let new_messages = from
.1
.get_new_messages(to.0 as UserID, &from_request_version_from_to)
.await
.unwrap();
assert!(new_messages.len() <= messages_count);
to.1.handle_user_discovery_messages(from.0 as UserID, new_messages)
.await
.unwrap();
assert_eq!(
to.1.should_request_new_messages(
from.0 as UserID,
&from.1.get_current_version().await.unwrap()
)
.await
.unwrap(),
false
);
}
const ALICE: usize = 0;
const BOB: usize = 1;
const CHARLIE: usize = 2;
const DAVID: usize = 3;
const FRANK: usize = 4;
const TEST_USER_COUNT: usize = 5;
struct TestUsers<S: UserDiscoveryStore> {
names: [&'static str; TEST_USER_COUNT],
friends: [Vec<usize>; TEST_USER_COUNT],
uds: Vec<UserDiscovery<S, TestingUtils>>,
}
impl<S: UserDiscoveryStore + Clone> TestUsers<S> {
async fn get() -> Self {
let names = ["ALICE", "BOB", "CHARLIE", "DAVID", "FRANK"];
let mut uds = vec![];
for index in 0..names.len() {
uds.push(get_ud(index).await);
}
let friends = [
vec![BOB, CHARLIE],
vec![ALICE, CHARLIE, DAVID],
vec![ALICE, BOB, DAVID, FRANK],
vec![BOB, CHARLIE],
vec![CHARLIE],
];
Self {
names,
uds,
friends,
}
}
}
pub async fn test_initialize_user_discovery<S: UserDiscoveryStore + Clone>() {
#[cfg(test)]
let _ = pretty_env_logger::try_init();
let users = TestUsers::<S>::get().await;
async fn to_all_friends<S: UserDiscoveryStore + Clone>(
from: usize,
message_count: usize,
users: &TestUsers<S>,
) {
for friend in &users.friends[from] {
tracing::debug!("From {} to {}", users.names[from], users.names[*friend]);
if message_count == 0 {
assert_new_messages(
(from, &users.uds[from]),
(*friend, &users.uds[*friend]),
false,
)
.await;
} else {
request_and_handle_messages(
(from, &users.uds[from]),
(*friend, &users.uds[*friend]),
message_count,
)
.await;
}
}
}
let message_flows = [
// ALICE: own announcement sending to BOB and CHARLIE
(ALICE, 1),
// BOB: own announcement + promotion for ALICE
(BOB, 2),
// BOBs version should not have any new messages for his friends
(BOB, 0),
// ALICE: promotion for BOB
(ALICE, 1),
// CHARLIE: own announcement + promotion for ALICE, BOB
(CHARLIE, 3),
// DAVID: own announcement + promotion for BOB, CHARLIE
(DAVID, 3),
// BOB: promotion for CHARLIE, DAVID
(BOB, 2),
// CHARLIE: promotion for DAVID
(CHARLIE, 1),
// FRANK: own announcement + promotion for CHARLIE
(FRANK, 2),
// CHARLIE: promotion for FRANK
(CHARLIE, 1),
// ALICE: promotion for CHARLIE
(ALICE, 1),
];
for (i, (from, count)) in message_flows.into_iter().enumerate() {
tracing::debug!("MESSAGE FLOW: {i}");
to_all_friends(from, count, &users).await;
}
tracing::debug!("Now all users should have the newest version.");
for from in 0..TEST_USER_COUNT {
for to in &users.friends[from] {
tracing::debug!(
"Does {} has open messages for {}?",
&users.names[from],
&users.names[*to]
);
assert_new_messages((from, &users.uds[from]), (*to, &users.uds[*to]), false).await;
}
}
tracing::debug!("Test if all exchanges where successful.");
let announced_users_expected = [
// ALICE should now know that BOB and CHARLIE, BOB and DAVID and CHARLIE and DAVID are friends.
// Alice should also have one protected share from Frank.
(
ALICE,
vec![
(BOB, vec![CHARLIE]), // ALICE knows Bob and that CHARLIE is connected with BOB
(CHARLIE, vec![BOB]), // ALICE knows CHARLIE and that BOB is connected with CHARLIE
(DAVID, vec![BOB, CHARLIE]), // ALICE knows DAVID and that BOB and CHARLIE are connected with DAVID
],
),
(
BOB,
vec![
(ALICE, vec![CHARLIE]),
(CHARLIE, vec![ALICE, DAVID]),
(DAVID, vec![CHARLIE]),
],
),
(
CHARLIE,
vec![
(ALICE, vec![BOB]),
(BOB, vec![ALICE, DAVID]),
(DAVID, vec![BOB]),
(FRANK, vec![]),
],
),
(
DAVID,
vec![
(ALICE, vec![BOB, CHARLIE]),
(BOB, vec![CHARLIE]),
(CHARLIE, vec![BOB]),
],
),
(FRANK, vec![(CHARLIE, vec![])]),
];
for (user, announcements) in announced_users_expected {
let announced_users2 = users.uds[user].get_all_announced_users().await.unwrap();
let mut announced_users = HashMap::new();
for a in announced_users2 {
announced_users.insert(a.0.user_id, a.1.iter().map(|x| x.0).collect::<Vec<_>>());
}
tracing::debug!("{} knows now: {}", users.names[user], announced_users.len());
assert_eq!(announced_users.len(), announcements.len());
for (contact_id, announced_users_expected) in announcements {
let announced_users = announced_users.get(&(contact_id as i64)).unwrap();
tracing::debug!(
"{} knows now that {} has the following friends: {}",
users.names[user],
users.names[contact_id],
announced_users
.iter()
.map(|x| users.names[*x as usize])
.collect::<Vec<_>>()
.join(", ")
);
let announced_users: HashSet<i64> = announced_users.iter().cloned().collect();
let announced_users_expected: HashSet<i64> = announced_users_expected
.iter()
.cloned()
.map(|x| x as i64)
.collect();
assert_eq!(announced_users, announced_users_expected);
}
}
}
#[tokio::test]
async fn test_initialize_user_discovery_in_memory_store() {
test_initialize_user_discovery::<crate::user_discovery::stores::InMemoryStore>().await;
}

View file

@ -2,11 +2,12 @@ use std::collections::HashMap;
use crate::user_discovery::error::Result;
use crate::user_discovery::UserID;
use std::future::Future;
#[derive(Clone)]
#[derive(Clone, sqlx::FromRow)]
pub struct OtherPromotion {
pub promotion_id: u32,
pub public_id: u64,
pub public_id: i64,
pub from_contact_id: UserID,
pub threshold: u8,
pub announcement_share: Vec<u8>,
@ -17,62 +18,97 @@ pub struct OtherPromotion {
pub struct AnnouncedUser {
pub user_id: UserID,
pub public_key: Vec<u8>,
pub public_id: u64,
pub public_id: i64,
}
pub trait UserDiscoveryStore {
fn get_config(&self) -> Result<Vec<u8>>;
fn update_config(&self, update: Vec<u8>) -> Result<()>;
fn set_shares(&self, shares: Vec<Vec<u8>>) -> Result<()>;
fn new() -> impl std::future::Future<Output = Self> + Send;
fn get_config(&self) -> impl Future<Output = Result<String>> + Send;
fn update_config(&self, update: String) -> impl Future<Output = Result<()>> + Send;
fn set_shares(&self, shares: Vec<Vec<u8>>) -> impl Future<Output = Result<()>> + Send;
fn get_share_for_contact(&self, contact_id: UserID) -> Result<Vec<u8>>;
fn get_share_for_contact(
&self,
contact_id: UserID,
) -> impl Future<Output = Result<Vec<u8>>> + Send;
fn push_own_promotion(
&self,
contact_id: UserID,
version: u32,
promotion: Vec<u8>,
) -> Result<()>;
) -> impl Future<Output = Result<()>> + Send;
fn get_own_promotions_after_version(&self, version: u32) -> Result<Vec<Vec<u8>>>;
fn get_own_promotions_after_version(
&self,
version: u32,
) -> impl Future<Output = Result<Vec<Vec<u8>>>> + Send;
fn store_other_promotion(&self, promotion: OtherPromotion) -> Result<()>;
fn get_other_promotions_by_public_id(&self, public_id: u64) -> Vec<OtherPromotion>;
fn store_other_promotion(
&self,
promotion: OtherPromotion,
) -> impl Future<Output = Result<()>> + Send;
fn get_other_promotions_by_public_id(
&self,
public_id: i64,
) -> impl Future<Output = Result<Vec<OtherPromotion>>> + Send;
fn get_announced_user_by_public_id(&self, public_id: u64) -> Result<Option<AnnouncedUser>>;
fn get_announced_user_by_public_id(
&self,
public_id: i64,
) -> impl Future<Output = Result<Option<AnnouncedUser>>> + Send;
fn push_new_user_relation(
&self,
from_contact_id: UserID,
announced_user: AnnouncedUser,
public_key_verified_timestamp: Option<i64>,
) -> Result<()>;
) -> impl Future<Output = Result<()>> + Send;
fn get_all_announced_users(&self)
-> Result<HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>>>;
fn get_all_announced_users(
&self,
) -> impl Future<Output = Result<HashMap<AnnouncedUser, Vec<(UserID, Option<i64>)>>>> + Send;
fn get_contact_version(&self, contact_id: UserID) -> Result<Option<Vec<u8>>>;
fn set_contact_version(&self, contact_id: UserID, update: Vec<u8>) -> Result<()>;
fn get_contact_version(
&self,
contact_id: UserID,
) -> impl Future<Output = Result<Option<Vec<u8>>>> + Send;
fn set_contact_version(
&self,
contact_id: UserID,
update: Vec<u8>,
) -> impl Future<Output = Result<()>> + Send;
}
pub trait UserDiscoveryUtils {
fn sign_data(&self, input_data: &[u8]) -> Result<Vec<u8>>;
fn verify_signature(&self, input_data: &[u8], pubkey: &[u8], signature: &[u8]) -> Result<bool>;
fn verify_stored_pubkey(&self, from_contact_id: UserID, pubkey: &[u8]) -> Result<bool>;
fn sign_data(&self, input_data: &[u8]) -> impl Future<Output = Result<Vec<u8>>> + Send;
fn verify_signature(
&self,
input_data: &[u8],
pubkey: &[u8],
signature: &[u8],
) -> impl Future<Output = Result<bool>> + Send;
fn verify_stored_pubkey(
&self,
from_contact_id: UserID,
pubkey: &[u8],
) -> impl Future<Output = Result<bool>> + Send;
}
#[cfg(test)]
pub(crate) mod tests {
use crate::user_discovery::traits::UserDiscoveryUtils;
#[derive(Default)]
pub(crate) struct TestingUtils {}
impl UserDiscoveryUtils for TestingUtils {
fn sign_data(&self, _input_data: &[u8]) -> crate::user_discovery::error::Result<Vec<u8>> {
async fn sign_data(
&self,
_input_data: &[u8],
) -> crate::user_discovery::error::Result<Vec<u8>> {
Ok(vec![0; 64])
}
fn verify_signature(
async fn verify_signature(
&self,
_data: &[u8],
_pubkey: &[u8],
@ -81,7 +117,7 @@ pub(crate) mod tests {
Ok(true)
}
fn verify_stored_pubkey(
async fn verify_stored_pubkey(
&self,
_from_contact_id: crate::user_discovery::UserID,
_pubkey: &[u8],

View file

@ -15,7 +15,7 @@ message UserDiscoveryMessage {
optional UserDiscoveryRecall user_discovery_recall = 4;
message UserDiscoveryAnnouncement {
uint64 public_id = 1;
int64 public_id = 1;
uint32 threshold = 2;
bytes announcement_share = 4;
repeated bytes verification_shares = 6;
@ -23,7 +23,7 @@ message UserDiscoveryMessage {
message UserDiscoveryPromotion {
uint32 promotion_id = 1;
uint64 public_id = 2;
int64 public_id = 2;
uint32 threshold = 3;
bytes announcement_share = 5;
@ -31,9 +31,9 @@ message UserDiscoveryMessage {
message AnnouncementShareDecrypted {
message SignedData {
uint64 public_id = 1;
int64 user_id = 2;
bytes public_key = 3;
int64 public_id = 1;
int64 user_id = 2;
bytes public_key = 3;
}
SignedData signed_data = 1;
bytes signature = 2;

View file

@ -15,6 +15,7 @@ import 'schema_v8.dart' as v8;
import 'schema_v9.dart' as v9;
import 'schema_v10.dart' as v10;
import 'schema_v11.dart' as v11;
import 'schema_v12.dart' as v12;
class GeneratedHelper implements SchemaInstantiationHelper {
@override
@ -42,10 +43,12 @@ class GeneratedHelper implements SchemaInstantiationHelper {
return v10.DatabaseAtV10(db);
case 11:
return v11.DatabaseAtV11(db);
case 12:
return v12.DatabaseAtV12(db);
default:
throw MissingSchemaException(version, versions);
}
}
static const versions = const [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
static const versions = const [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
}

File diff suppressed because it is too large Load diff

View file

@ -11,7 +11,7 @@ void main() {
late File dbFile;
setUp(() {
dbFile = File('rust/tests/testing.db');
dbFile = File('rust/core/tests/testing.db');
if (dbFile.existsSync()) {
dbFile.deleteSync();
}